github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/inet.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datacodec 16 17 import ( 18 "errors" 19 "net" 20 21 "github.com/datastax/go-cassandra-native-protocol/datatype" 22 "github.com/datastax/go-cassandra-native-protocol/primitive" 23 ) 24 25 // Inet is a codec for the CQL inet type. Its preferred Go type is net.IP but it can encode from and decode to 26 // []byte as well. 27 var Inet Codec = &inetCodec{} 28 29 type inetCodec struct{} 30 31 func (c *inetCodec) DataType() datatype.DataType { 32 return datatype.Inet 33 } 34 35 func (c *inetCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { 36 var val net.IP 37 if val, err = convertToIP(source); err == nil && val != nil { 38 dest, err = writeInet(val) 39 } 40 if err != nil { 41 err = errCannotEncode(source, c.DataType(), version, err) 42 } 43 return 44 } 45 46 func (c *inetCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { 47 var val net.IP 48 if val, wasNull, err = readInet(source); err == nil { 49 err = convertFromIP(val, wasNull, dest) 50 } 51 if err != nil { 52 err = errCannotDecode(dest, c.DataType(), version, err) 53 } 54 return 55 } 56 57 func convertToIP(source interface{}) (val net.IP, err error) { 58 switch s := source.(type) { 59 case net.IP: 60 val = s 61 val4 := val.To4() 62 if val4 != nil { 63 val = val4 64 } 65 case *net.IP: 66 if s != nil { 67 val = *s 68 val4 := val.To4() 69 if val4 != nil { 70 val = val4 71 } 72 } 73 case []byte: 74 val = s 75 val4 := val.To4() 76 if val4 != nil { 77 val = val4 78 } 79 case *[]byte: 80 if s != nil { 81 val = *s 82 val4 := val.To4() 83 if val4 != nil { 84 val = val4 85 } 86 } 87 case string: 88 val = compactV4(net.ParseIP(s)) 89 if val == nil { 90 err = errCannotParseString(s, errors.New("net.ParseIP(text) failed")) 91 } 92 case *string: 93 if s != nil { 94 val = compactV4(net.ParseIP(*s)) 95 if val == nil { 96 err = errCannotParseString(*s, errors.New("net.ParseIP(text) failed")) 97 } 98 } 99 case nil: 100 default: 101 err = ErrConversionNotSupported 102 } 103 if err != nil { 104 err = errSourceConversionFailed(source, val, err) 105 } 106 return 107 } 108 109 func convertFromIP(val net.IP, wasNull bool, dest interface{}) (err error) { 110 switch d := dest.(type) { 111 case *interface{}: 112 if d == nil { 113 err = ErrNilDestination 114 } else if wasNull { 115 *d = nil 116 } else { 117 *d = val 118 } 119 case *net.IP: 120 if d == nil { 121 err = ErrNilDestination 122 } else if wasNull { 123 *d = nil 124 } else { 125 *d = compactV4(val) 126 } 127 case *[]byte: 128 if d == nil { 129 err = ErrNilDestination 130 } else if wasNull { 131 *d = nil 132 } else { 133 *d = compactV4(val) 134 } 135 case *string: 136 if d == nil { 137 err = ErrNilDestination 138 } else if wasNull { 139 *d = "" 140 } else { 141 *d = val.String() 142 } 143 default: 144 err = errDestinationInvalid(dest) 145 } 146 if err != nil { 147 err = errDestinationConversionFailed(val, dest, err) 148 } 149 return 150 } 151 152 func writeInet(val net.IP) (dest []byte, err error) { 153 length := len(val) 154 if length == 0 { 155 dest = nil 156 } else if length == net.IPv4len || length == net.IPv6len { 157 dest = compactV4(val) 158 } else { 159 err = errWrongFixedLengths(net.IPv4len, net.IPv6len, length) 160 } 161 if err != nil { 162 err = errCannotWrite(val, err) 163 } 164 return 165 } 166 167 // The below functions are roughly equivalent to primitive.ReadInetAddr and primitive.WriteInetAddr. 168 // They favor the compact form (4-byte slice) for IPv4 addresses. 169 170 func readInet(source []byte) (val net.IP, wasNull bool, err error) { 171 length := len(source) 172 if length == 0 { 173 wasNull = true 174 } else if length == net.IPv4len { 175 val = net.IPv4(source[0], source[1], source[2], source[3]).To4() 176 } else if length == net.IPv6len { 177 val = source 178 } else { 179 err = errWrongFixedLengths(net.IPv4len, net.IPv6len, length) 180 } 181 if err != nil { 182 err = errCannotRead(val, err) 183 } 184 return 185 } 186 187 func compactV4(val net.IP) net.IP { 188 if val != nil { 189 val4 := val.To4() 190 if val4 != nil { 191 return val4 192 } 193 } 194 return val 195 }