github.com/keysonZZZ/kmg@v0.0.0-20151121023212-05317bfd7d39/third/kmgRadius/packet.go (about) 1 package kmgRadius 2 3 import ( 4 "crypto" 5 "crypto/hmac" 6 _ "crypto/md5" 7 "crypto/rand" 8 "encoding/binary" 9 "fmt" 10 "net" 11 "strconv" 12 13 "github.com/bronze1man/kmg/third/kmgRadius/eap" 14 ) 15 16 var ErrMessageAuthenticatorCheckFail = fmt.Errorf("RADIUS Response-Authenticator verification failed") 17 18 const maxPacketLength = 4096 // rfc2058 Page 9 Length 19 20 type Packet struct { 21 Secret []byte 22 Code Code 23 Identifier uint8 24 Authenticator [16]byte //对应的Request请求里面的Authenticator. 25 AVPs []AVP 26 } 27 28 func (p *Packet) Copy() *Packet { 29 outP := &Packet{ 30 Secret: p.Secret, 31 Code: p.Code, 32 Identifier: p.Identifier, 33 Authenticator: p.Authenticator, //这个应该是拷贝 34 } 35 outP.AVPs = make([]AVP, len(p.AVPs)) 36 for i := range p.AVPs { 37 outP.AVPs[i] = p.AVPs[i].Copy() 38 } 39 return outP 40 } 41 42 //此方法保证不修改包的内容 43 func (p *Packet) Encode() (b []byte, err error) { 44 p = p.Copy() 45 p.SetAVP(&BinaryAVP{ 46 Type: AVPTypeMessageAuthenticator, 47 Value: make([]byte, 16), 48 }) 49 if p.Code == CodeAccessRequest { 50 _, err := rand.Read(p.Authenticator[:]) 51 if err != nil { 52 return nil, err 53 } 54 } 55 //TODO request的时候重新计算密码 56 b, err = p.encodeNoHash() 57 if err != nil { 58 return 59 } 60 //计算Message-Authenticator这个AVP的值,特殊化处理,Message-Authenticator这个AVP被放在最后面 61 hasher := hmac.New(crypto.MD5.New, []byte(p.Secret)) 62 hasher.Write(b) 63 copy(b[len(b)-16:len(b)], hasher.Sum(nil)) 64 65 // fix up the authenticator 66 // handle request and response stuff. 67 // here only handle response part. 68 switch p.Code { 69 case CodeAccessRequest: 70 case CodeAccessAccept, CodeAccessReject, CodeAccessChallenge, 71 CodeAccountingRequest, CodeAccountingResponse: 72 //rfc2865 page 15 Response Authenticator 73 //rfc2866 page 6 Response Authenticator 74 //rfc2866 page 6 Request Authenticator 75 hasher := crypto.Hash(crypto.MD5).New() 76 hasher.Write(b) 77 hasher.Write([]byte(p.Secret)) 78 copy(b[4:20], hasher.Sum(nil)) //返回值再把Authenticator写回去 79 default: 80 return nil, fmt.Errorf("not handle p.Code %d", p.Code) 81 } 82 83 return b, err 84 } 85 86 func (p *Packet) encodeNoHash() (b []byte, err error) { 87 b = make([]byte, maxPacketLength) 88 b[0] = uint8(p.Code) 89 b[1] = uint8(p.Identifier) 90 copy(b[4:20], p.Authenticator[:]) 91 written := 20 92 bb := b[20:] 93 for i, _ := range p.AVPs { 94 bb1, err := p.AVPs[i].Encode() 95 if err != nil { 96 return nil, err 97 } 98 written += len(bb1) 99 if written > maxPacketLength { 100 return nil, fmt.Errorf("[Packet.encodeNoHash] packet too large written[%d]>maxPacketLength[%d]", 101 written, maxPacketLength) 102 } 103 copy(bb, bb1) 104 bb = bb[len(bb1):] 105 } 106 binary.BigEndian.PutUint16(b[2:4], uint16(written)) 107 return b[:written], nil 108 } 109 110 //get one avp 111 func (p *Packet) GetAVP(attrType AVPType) AVP { 112 for i := range p.AVPs { 113 if p.AVPs[i].GetType() == attrType { 114 return p.AVPs[i] 115 } 116 } 117 return nil 118 } 119 120 //set one avp,remove all other same type 121 func (p *Packet) SetAVP(avp AVP) { 122 p.DeleteOneType(avp.GetType()) 123 p.AddAVP(avp) 124 } 125 126 func (p *Packet) AddAVP(avp AVP) { 127 p.AVPs = append(p.AVPs, avp) 128 } 129 130 func (p *Packet) GetVsa(typ VendorType) VSA { 131 for i := range p.AVPs { 132 if p.AVPs[i].GetType() != AVPTypeVendorSpecific { 133 continue 134 } 135 vsa, ok := p.AVPs[i].(*VendorSpecificAVP) 136 if !ok { 137 continue //允许使用binaryAVP代表一个AVPTypeVendorSpecific 138 } 139 if vsa.Value.GetType() != typ { 140 continue 141 } 142 return vsa.Value 143 } 144 return nil 145 } 146 147 //删除一个AVP 148 /* 149 func (p *Packet) DeleteAVP(avp AVP) { 150 for i := range p.AVPs { 151 if &(p.AVPs[i]) == avp { 152 for j := i; j < len(p.AVPs)-1; j++ { 153 p.AVPs[j] = p.AVPs[j+1] 154 } 155 p.AVPs = p.AVPs[:len(p.AVPs)-1] 156 break 157 } 158 } 159 return 160 } 161 */ 162 163 //delete all avps with this type 164 func (p *Packet) DeleteOneType(attrType AVPType) { 165 for i := 0; i < len(p.AVPs); i++ { 166 if p.AVPs[i].GetType() == attrType { 167 for j := i; j < len(p.AVPs)-1; j++ { 168 p.AVPs[j] = p.AVPs[j+1] 169 } 170 p.AVPs = p.AVPs[:len(p.AVPs)-1] 171 i-- 172 break 173 } 174 } 175 return 176 } 177 178 func (p *Packet) Reply() *Packet { 179 pac := new(Packet) 180 pac.Authenticator = p.Authenticator 181 pac.Identifier = p.Identifier 182 pac.Secret = p.Secret 183 state := p.GetState() 184 if len(state) > 0 { 185 pac.SetState(state) 186 } 187 return pac 188 } 189 190 func (p *Packet) Send(c net.PacketConn, addr net.Addr) error { 191 buf, err := p.Encode() 192 if err != nil { 193 return err 194 } 195 196 _, err = c.WriteTo(buf, addr) 197 return err 198 } 199 200 // 这个只能解密各种Request 201 func DecodeRequestPacket(Secret []byte, buf []byte) (p *Packet, err error) { 202 p = &Packet{Secret: Secret} 203 p.Code = Code(buf[0]) 204 p.Identifier = buf[1] 205 copy(p.Authenticator[:], buf[4:20]) 206 //read attributes 207 b := buf[20:] 208 for { 209 if len(b) == 0 { 210 break 211 } 212 if len(b) < 2 { 213 return nil, fmt.Errorf("[radius.DecodePacket] unexcept EOF") 214 } 215 length := uint8(b[1]) 216 if int(length) > len(b) { 217 return nil, fmt.Errorf("[radius.DecodePacket] invalid avp length len:%d len(b):%d", length, len(b)) 218 } 219 avp, err := avpDecode(p, b[:length]) 220 if err != nil { 221 return nil, err 222 } 223 p.AVPs = append(p.AVPs, avp) 224 b = b[length:] 225 } 226 //验证Message-Authenticator,并且通过测试验证此处算法是正确的 227 //此处不修改Message-Authenticator的值 228 err = p.checkMessageAuthenticator() 229 if err != nil { 230 return p, err 231 } 232 return p, nil 233 } 234 235 // 解密response包 236 func DecodeResponsePacket(Secret []byte, buf []byte, RequestAuthenticator [16]byte) (p *Packet, err error) { 237 p = &Packet{ 238 Secret: Secret, 239 Authenticator: RequestAuthenticator, 240 } 241 p.Code = Code(buf[0]) 242 p.Identifier = buf[1] 243 //read attributes 244 b := buf[20:] 245 for { 246 if len(b) == 0 { 247 break 248 } 249 if len(b) < 2 { 250 return nil, fmt.Errorf("[radius.DecodePacket] unexcept EOF") 251 } 252 length := uint8(b[1]) 253 if int(length) > len(b) { 254 return nil, fmt.Errorf("[radius.DecodePacket] invalid avp length len:%d len(b):%d", length, len(b)) 255 } 256 avp, err := avpDecode(p, b[:length]) 257 if err != nil { 258 return nil, err 259 } 260 p.AVPs = append(p.AVPs, avp) 261 b = b[length:] 262 } 263 //验证Message-Authenticator,并且通过测试验证此处算法是正确的 264 //此处不修改Message-Authenticator的值 265 err = p.checkMessageAuthenticator() 266 if err != nil { 267 return p, err 268 } 269 return p, nil 270 } 271 272 //如果没有MessageAuthenticator也算通过 273 func (p *Packet) checkMessageAuthenticator() (err error) { 274 AuthenticatorI := p.GetAVP(AVPTypeMessageAuthenticator) 275 if AuthenticatorI == nil { 276 return nil 277 } 278 Authenticator := AuthenticatorI.(*BinaryAVP) 279 AuthenticatorValue := Authenticator.Value 280 defer func() { Authenticator.Value = AuthenticatorValue }() 281 Authenticator.Value = make([]byte, 16) 282 content, err := p.encodeNoHash() 283 if err != nil { 284 return err 285 } 286 hasher := hmac.New(crypto.MD5.New, []byte(p.Secret)) 287 hasher.Write(content) 288 if !hmac.Equal(hasher.Sum(nil), AuthenticatorValue) { 289 return ErrMessageAuthenticatorCheckFail 290 } 291 return nil 292 } 293 294 func (p *Packet) String() string { 295 s := "Code: " + p.Code.String() + "\n" + 296 "Identifier: " + strconv.Itoa(int(p.Identifier)) + "\n" + 297 "Authenticator: " + fmt.Sprintf("%#v", p.Authenticator) + "\n" 298 for _, avp := range p.AVPs { 299 s += avp.String() + "\n" 300 } 301 return s 302 } 303 304 //转成字符串map,便于进行log(序列化?),只有实际信息,已经把加密的东西剔除掉了 305 func (p *Packet) ToStringMap() map[string]string { 306 out := make(map[string]string, len(p.AVPs)) 307 for _, avp := range p.AVPs { 308 if avp.GetType() == AVPTypeMessageAuthenticator { 309 continue 310 } 311 out[avp.GetType().String()] = avp.ValueAsString() 312 } 313 out["Code"] = p.Code.String() 314 out["Identifier"] = strconv.Itoa(int(p.Identifier)) 315 return out 316 } 317 318 func (p *Packet) GetUsername() (username string) { 319 avp := p.GetAVP(AVPTypeUserName) 320 if avp == nil { 321 return "" 322 } 323 return avp.(*StringAVP).Value 324 } 325 func (p *Packet) GetPassword() (password string) { 326 avp := p.GetAVP(AVPTypeUserPassword) 327 if avp == nil { 328 return "" 329 } 330 return avp.(*PasswordAVP).Value 331 } 332 333 func (p *Packet) GetNasIpAddress() (ip net.IP) { 334 avp := p.GetAVP(AVPTypeNASIPAddress) 335 if avp == nil { 336 return nil 337 } 338 return avp.(*IpAVP).Value 339 } 340 341 func (p *Packet) GetAcctStatusType() AcctStatusTypeEnum { 342 avp := p.GetAVP(AVPTypeAcctStatusType) 343 if avp == nil { 344 return AcctStatusTypeEnum(0) 345 } 346 return avp.(*Uint32EnumAVP).Value.(AcctStatusTypeEnum) 347 } 348 349 func (p *Packet) GetAcctSessionId() string { 350 avp := p.GetAVP(AVPTypeAcctSessionId) 351 if avp == nil { 352 return "" 353 } 354 return avp.(*StringAVP).Value 355 } 356 357 func (p *Packet) GetAcctTotalOutputOctets() uint64 { 358 out := uint64(0) 359 avp := p.GetAVP(AVPTypeAcctOutputOctets) 360 if avp != nil { 361 out += uint64(avp.(*Uint32AVP).Value) 362 } 363 avp = p.GetAVP(AVPTypeAcctOutputGigawords) 364 if avp != nil { 365 out += uint64(avp.(*Uint32AVP).Value) * (2 ^ 32) 366 } 367 return out 368 } 369 370 func (p *Packet) GetAcctTotalInputOctets() uint64 { 371 out := uint64(0) 372 avp := p.GetAVP(AVPTypeAcctInputOctets) 373 if avp != nil { 374 out += uint64(avp.(*Uint32AVP).Value) 375 } 376 avp = p.GetAVP(AVPTypeAcctInputGigawords) 377 if avp != nil { 378 out += uint64(avp.(*Uint32AVP).Value) * (2 ^ 32) 379 } 380 return out 381 } 382 383 // it is ike_id in strongswan client 384 func (p *Packet) GetNASPort() uint32 { 385 avp := p.GetAVP(AVPTypeNASPort) 386 if avp == nil { 387 return 0 388 } 389 return avp.(*Uint32AVP).Value 390 } 391 392 func (p *Packet) GetNASIdentifier() string { 393 avp := p.GetAVP(AVPTypeNASIdentifier) 394 if avp == nil { 395 return "" 396 } 397 return avp.(*StringAVP).Value 398 } 399 400 func (p *Packet) GetEAPMessage() eap.Packet { 401 avp := p.GetAVP(AVPTypeEAPMessage) 402 if avp == nil { 403 return nil 404 } 405 return avp.(*EapAVP).Value 406 } 407 408 func (p *Packet) GetState() []byte { 409 avp := p.GetAVP(AVPTypeState) 410 if avp == nil { 411 return nil 412 } 413 return avp.GetValue().([]byte) 414 } 415 416 func (p *Packet) SetState(state []byte) { 417 p.SetAVP(&BinaryAVP{ 418 Type: AVPTypeState, 419 Value: state, 420 }) 421 } 422 423 func (p *Packet) SetAcctInterimInterval(second int) { 424 p.SetAVP(&Uint32AVP{ 425 Type: AVPTypeAcctInterimInterval, 426 Value: uint32(second), 427 }) 428 } 429 430 func (p *Packet) GetAcctSessionTime() uint32 { 431 avp := p.GetAVP(AVPTypeAcctSessionTime) 432 if avp == nil { 433 return 0 434 } 435 return avp.GetValue().(uint32) 436 }