github.com/jackc/pgx/v5@v5.5.5/pgproto3/bind.go (about) 1 package pgproto3 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "encoding/hex" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "math" 11 12 "github.com/jackc/pgx/v5/internal/pgio" 13 ) 14 15 type Bind struct { 16 DestinationPortal string 17 PreparedStatement string 18 ParameterFormatCodes []int16 19 Parameters [][]byte 20 ResultFormatCodes []int16 21 } 22 23 // Frontend identifies this message as sendable by a PostgreSQL frontend. 24 func (*Bind) Frontend() {} 25 26 // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 27 // type identifier and 4 byte message length. 28 func (dst *Bind) Decode(src []byte) error { 29 *dst = Bind{} 30 31 idx := bytes.IndexByte(src, 0) 32 if idx < 0 { 33 return &invalidMessageFormatErr{messageType: "Bind"} 34 } 35 dst.DestinationPortal = string(src[:idx]) 36 rp := idx + 1 37 38 idx = bytes.IndexByte(src[rp:], 0) 39 if idx < 0 { 40 return &invalidMessageFormatErr{messageType: "Bind"} 41 } 42 dst.PreparedStatement = string(src[rp : rp+idx]) 43 rp += idx + 1 44 45 if len(src[rp:]) < 2 { 46 return &invalidMessageFormatErr{messageType: "Bind"} 47 } 48 parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) 49 rp += 2 50 51 if parameterFormatCodeCount > 0 { 52 dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) 53 54 if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { 55 return &invalidMessageFormatErr{messageType: "Bind"} 56 } 57 for i := 0; i < parameterFormatCodeCount; i++ { 58 dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) 59 rp += 2 60 } 61 } 62 63 if len(src[rp:]) < 2 { 64 return &invalidMessageFormatErr{messageType: "Bind"} 65 } 66 parameterCount := int(binary.BigEndian.Uint16(src[rp:])) 67 rp += 2 68 69 if parameterCount > 0 { 70 dst.Parameters = make([][]byte, parameterCount) 71 72 for i := 0; i < parameterCount; i++ { 73 if len(src[rp:]) < 4 { 74 return &invalidMessageFormatErr{messageType: "Bind"} 75 } 76 77 msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) 78 rp += 4 79 80 // null 81 if msgSize == -1 { 82 continue 83 } 84 85 if len(src[rp:]) < msgSize { 86 return &invalidMessageFormatErr{messageType: "Bind"} 87 } 88 89 dst.Parameters[i] = src[rp : rp+msgSize] 90 rp += msgSize 91 } 92 } 93 94 if len(src[rp:]) < 2 { 95 return &invalidMessageFormatErr{messageType: "Bind"} 96 } 97 resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) 98 rp += 2 99 100 dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) 101 if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { 102 return &invalidMessageFormatErr{messageType: "Bind"} 103 } 104 for i := 0; i < resultFormatCodeCount; i++ { 105 dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) 106 rp += 2 107 } 108 109 return nil 110 } 111 112 // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 113 func (src *Bind) Encode(dst []byte) ([]byte, error) { 114 dst, sp := beginMessage(dst, 'B') 115 116 dst = append(dst, src.DestinationPortal...) 117 dst = append(dst, 0) 118 dst = append(dst, src.PreparedStatement...) 119 dst = append(dst, 0) 120 121 if len(src.ParameterFormatCodes) > math.MaxUint16 { 122 return nil, errors.New("too many parameter format codes") 123 } 124 dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) 125 for _, fc := range src.ParameterFormatCodes { 126 dst = pgio.AppendInt16(dst, fc) 127 } 128 129 if len(src.Parameters) > math.MaxUint16 { 130 return nil, errors.New("too many parameters") 131 } 132 dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) 133 for _, p := range src.Parameters { 134 if p == nil { 135 dst = pgio.AppendInt32(dst, -1) 136 continue 137 } 138 139 dst = pgio.AppendInt32(dst, int32(len(p))) 140 dst = append(dst, p...) 141 } 142 143 if len(src.ResultFormatCodes) > math.MaxUint16 { 144 return nil, errors.New("too many result format codes") 145 } 146 dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) 147 for _, fc := range src.ResultFormatCodes { 148 dst = pgio.AppendInt16(dst, fc) 149 } 150 151 return finishMessage(dst, sp) 152 } 153 154 // MarshalJSON implements encoding/json.Marshaler. 155 func (src Bind) MarshalJSON() ([]byte, error) { 156 formattedParameters := make([]map[string]string, len(src.Parameters)) 157 for i, p := range src.Parameters { 158 if p == nil { 159 continue 160 } 161 162 textFormat := true 163 if len(src.ParameterFormatCodes) == 1 { 164 textFormat = src.ParameterFormatCodes[0] == 0 165 } else if len(src.ParameterFormatCodes) > 1 { 166 textFormat = src.ParameterFormatCodes[i] == 0 167 } 168 169 if textFormat { 170 formattedParameters[i] = map[string]string{"text": string(p)} 171 } else { 172 formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} 173 } 174 } 175 176 return json.Marshal(struct { 177 Type string 178 DestinationPortal string 179 PreparedStatement string 180 ParameterFormatCodes []int16 181 Parameters []map[string]string 182 ResultFormatCodes []int16 183 }{ 184 Type: "Bind", 185 DestinationPortal: src.DestinationPortal, 186 PreparedStatement: src.PreparedStatement, 187 ParameterFormatCodes: src.ParameterFormatCodes, 188 Parameters: formattedParameters, 189 ResultFormatCodes: src.ResultFormatCodes, 190 }) 191 } 192 193 // UnmarshalJSON implements encoding/json.Unmarshaler. 194 func (dst *Bind) UnmarshalJSON(data []byte) error { 195 // Ignore null, like in the main JSON package. 196 if string(data) == "null" { 197 return nil 198 } 199 200 var msg struct { 201 DestinationPortal string 202 PreparedStatement string 203 ParameterFormatCodes []int16 204 Parameters []map[string]string 205 ResultFormatCodes []int16 206 } 207 err := json.Unmarshal(data, &msg) 208 if err != nil { 209 return err 210 } 211 dst.DestinationPortal = msg.DestinationPortal 212 dst.PreparedStatement = msg.PreparedStatement 213 dst.ParameterFormatCodes = msg.ParameterFormatCodes 214 dst.Parameters = make([][]byte, len(msg.Parameters)) 215 dst.ResultFormatCodes = msg.ResultFormatCodes 216 for n, parameter := range msg.Parameters { 217 dst.Parameters[n], err = getValueFromJSON(parameter) 218 if err != nil { 219 return fmt.Errorf("cannot get param %d: %w", n, err) 220 } 221 } 222 return nil 223 }