github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/prog/checksum.go (about) 1 // Copyright 2017 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package prog 5 6 import ( 7 "fmt" 8 ) 9 10 type CsumChunkKind int 11 12 const ( 13 CsumChunkArg CsumChunkKind = iota 14 CsumChunkConst 15 ) 16 17 type CsumInfo struct { 18 Kind CsumKind 19 Chunks []CsumChunk 20 } 21 22 type CsumChunk struct { 23 Kind CsumChunkKind 24 Arg Arg // for CsumChunkArg 25 Value uint64 // for CsumChunkConst 26 Size uint64 // for CsumChunkConst 27 } 28 29 func calcChecksumsCall(c *Call) (map[Arg]CsumInfo, map[Arg]struct{}) { 30 var inetCsumFields, pseudoCsumFields []Arg 31 32 // Find all csum fields. 33 ForeachArg(c, func(arg Arg, _ *ArgCtx) { 34 if typ, ok := arg.Type().(*CsumType); ok { 35 switch typ.Kind { 36 case CsumInet: 37 inetCsumFields = append(inetCsumFields, arg) 38 case CsumPseudo: 39 pseudoCsumFields = append(pseudoCsumFields, arg) 40 default: 41 panic(fmt.Sprintf("unknown csum kind %v", typ.Kind)) 42 } 43 } 44 }) 45 46 if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 { 47 return nil, nil 48 } 49 50 // Build map of each field to its parent struct. 51 parentsMap := make(map[Arg]Arg) 52 ForeachArg(c, func(arg Arg, _ *ArgCtx) { 53 if _, ok := arg.Type().(*StructType); ok { 54 for _, field := range arg.(*GroupArg).Inner { 55 parentsMap[InnerArg(field)] = arg 56 } 57 } 58 }) 59 60 csumMap := make(map[Arg]CsumInfo) 61 csumUses := make(map[Arg]struct{}) 62 63 // Calculate generic inet checksums. 64 for _, arg := range inetCsumFields { 65 typ, _ := arg.Type().(*CsumType) 66 csummedArg := findCsummedArg(arg, typ, parentsMap) 67 csumUses[csummedArg] = struct{}{} 68 chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0} 69 csumMap[arg] = CsumInfo{Kind: CsumInet, Chunks: []CsumChunk{chunk}} 70 } 71 72 // No need to continue if there are no pseudo csum fields. 73 if len(pseudoCsumFields) == 0 { 74 return csumMap, csumUses 75 } 76 77 // Extract ipv4 or ipv6 source and destination addresses. 78 var ipSrcAddr, ipDstAddr Arg 79 ForeachArg(c, func(arg Arg, _ *ArgCtx) { 80 groupArg, ok := arg.(*GroupArg) 81 if !ok { 82 return 83 } 84 // syz_csum_* structs are used in tests 85 switch groupArg.Type().TemplateName() { 86 case "ipv4_header", "syz_csum_ipv4_header": 87 ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 4) 88 case "ipv6_packet_t", "syz_csum_ipv6_header": 89 ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 16) 90 } 91 }) 92 if ipSrcAddr == nil || ipDstAddr == nil { 93 panic("no ipv4 nor ipv6 header found") 94 } 95 96 // Calculate pseudo checksums. 97 for _, arg := range pseudoCsumFields { 98 typ, _ := arg.Type().(*CsumType) 99 csummedArg := findCsummedArg(arg, typ, parentsMap) 100 protocol := uint8(typ.Protocol) 101 var info CsumInfo 102 if ipSrcAddr.Size() == 4 { 103 info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol) 104 } else { 105 info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol) 106 } 107 csumMap[arg] = info 108 csumUses[csummedArg] = struct{}{} 109 csumUses[ipSrcAddr] = struct{}{} 110 csumUses[ipDstAddr] = struct{}{} 111 } 112 113 return csumMap, csumUses 114 } 115 116 func findCsummedArg(arg Arg, typ *CsumType, parentsMap map[Arg]Arg) Arg { 117 if typ.Buf == ParentRef { 118 csummedArg := parentsMap[arg] 119 if csummedArg == nil { 120 panic(fmt.Sprintf("%q for %q is not in parents map", ParentRef, typ.Name())) 121 } 122 return csummedArg 123 } 124 for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] { 125 // TODO(dvyukov): support template argument names as in size calculation. 126 if typ.Buf == parent.Type().Name() { 127 return parent 128 } 129 } 130 panic(fmt.Sprintf("csum field %q references non existent field %q", typ.Name(), typ.Buf)) 131 } 132 133 func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo { 134 info := CsumInfo{Kind: CsumInet} 135 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) 136 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) 137 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(protocol))), 2}) 138 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(tcpPacket.Size()))), 2}) 139 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) 140 return info 141 } 142 143 func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo { 144 info := CsumInfo{Kind: CsumInet} 145 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) 146 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) 147 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(tcpPacket.Size()))), 4}) 148 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(protocol))), 4}) 149 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) 150 return info 151 } 152 153 func extractHeaderParams(arg *GroupArg, size uint64) (Arg, Arg) { 154 srcAddr := getFieldByName(arg, "src_ip") 155 dstAddr := getFieldByName(arg, "dst_ip") 156 if srcAddr.Size() != size || dstAddr.Size() != size { 157 panic(fmt.Sprintf("src/dst_ip fields in %v must be %v bytes", arg.Type().Name(), size)) 158 } 159 return srcAddr, dstAddr 160 } 161 162 func getFieldByName(arg *GroupArg, name string) Arg { 163 typ := arg.Type().(*StructType) 164 for i, field := range arg.Inner { 165 if typ.Fields[i].Name == name { 166 return field 167 } 168 } 169 panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name())) 170 }