go.uber.org/yarpc@v1.72.1/inject.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package yarpc 22 23 import ( 24 "fmt" 25 "reflect" 26 27 "go.uber.org/yarpc/api/transport" 28 ) 29 30 var ( 31 // _clientBuilders is a map from type of our desired client 'T' to a 32 // (reflected) function with one of the following signatures, 33 // 34 // func(transport.ClientConfig) T 35 // func(transport.ClientConfig, reflect.StructField) T 36 // 37 // Where T is the same as the key type for that entry. 38 _clientBuilders = make(map[reflect.Type]reflect.Value) 39 40 _typeOfClientConfig = reflect.TypeOf((*transport.ClientConfig)(nil)).Elem() 41 _typeOfStructField = reflect.TypeOf(reflect.StructField{}) 42 ) 43 44 func validateClientBuilder(f interface{}) reflect.Value { 45 if f == nil { 46 panic("must not be nil") 47 } 48 49 fv := reflect.ValueOf(f) 50 ft := fv.Type() 51 switch { 52 case ft.Kind() != reflect.Func: 53 panic(fmt.Sprintf("must be a function, not %v", ft)) 54 55 // Validate number of arguments and results 56 case ft.NumIn() == 0: 57 panic("must accept at least one argument") 58 case ft.NumIn() > 2: 59 panic(fmt.Sprintf("must accept at most two arguments, got %v", ft.NumIn())) 60 case ft.NumOut() != 1: 61 panic(fmt.Sprintf("must return exactly one result, got %v", ft.NumOut())) 62 63 // Validate input and output types 64 case ft.In(0) != _typeOfClientConfig: 65 panic(fmt.Sprintf("must accept a transport.ClientConfig as its first argument, got %v", ft.In(0))) 66 case ft.NumIn() == 2 && ft.In(1) != _typeOfStructField: 67 panic(fmt.Sprintf("if a second argument is accepted, it must be a reflect.StructField, got %v", ft.In(1))) 68 case ft.Out(0).Kind() != reflect.Interface: 69 panic(fmt.Sprintf("must return a single interface type as a result, got %v", ft.Out(0).Kind())) 70 } 71 72 return fv 73 } 74 75 // RegisterClientBuilder registers a builder function for a specific client 76 // type. 77 // 78 // Functions must have one of the following signatures: 79 // 80 // func(transport.ClientConfig) T 81 // func(transport.ClientConfig, reflect.StructField) T 82 // 83 // Where T is the type of the client. T MUST be an interface. In the second 84 // form, the function receives type information about the field being filled. 85 // It may inspect the struct tags to customize its behavior. 86 // 87 // This function panics if a client for the given type has already been 88 // registered. 89 // 90 // After a builder function for a client type is registered, these objects can 91 // be instantiated automatically using InjectClients. 92 // 93 // A function to unregister the builder function is returned. Note that the 94 // function will clear whatever the corresponding type's builder function is 95 // at the time it is called, regardless of whether the value matches what was 96 // passed to this function or not. 97 func RegisterClientBuilder(f interface{}) (forget func()) { 98 fv := validateClientBuilder(f) 99 t := fv.Type().Out(0) 100 101 if _, conflict := _clientBuilders[t]; conflict { 102 panic(fmt.Sprintf("a builder for %v has already been registered", t)) 103 } 104 105 _clientBuilders[t] = fv 106 return func() { delete(_clientBuilders, t) } 107 } 108 109 // InjectClients injects clients from a Dispatcher into the given struct. dest 110 // must be a pointer to a struct with zero or more exported fields which hold 111 // YARPC client types. This includes json.Client, raw.Client, and any 112 // generated Thrift service client. Fields with nil values and a `service` tag 113 // will be populated with clients using that service`s ClientConfig. 114 // 115 // Given, 116 // 117 // type Handler struct { 118 // KeyValueClient keyvalueclient.Interface `service:"keyvalue"` 119 // UserClient json.Client `service:"users"` 120 // TagClient tagclient.Interface // no tag; will be left unchanged 121 // } 122 // 123 // The call, 124 // 125 // var h Handler 126 // yarpc.InjectClients(dispatcher, &h) 127 // 128 // Is equivalent to, 129 // 130 // var h Handler 131 // h.KeyValueClient = keyvalueclient.New(dispatcher.ClientConfig("keyvalue")) 132 // h.UserClient = json.New(dispatcher.ClientConfig("users")) 133 // 134 // Builder functions for different client types may be registered using the 135 // RegisterClientBuilder function. 136 // 137 // This function panics if a field with an unknown type and nil value has the 138 // `service` tag. 139 func InjectClients(src transport.ClientConfigProvider, dest interface{}) { 140 destV := reflect.ValueOf(dest) 141 destT := reflect.TypeOf(dest) 142 if destT.Kind() != reflect.Ptr || destT.Elem().Kind() != reflect.Struct { 143 panic(fmt.Sprintf("dest must be a pointer to a struct, not %T", dest)) 144 } 145 146 structV := destV.Elem() 147 structT := destT.Elem() 148 for i := 0; i < structV.NumField(); i++ { 149 fieldInfo := structT.Field(i) 150 fieldV := structV.Field(i) 151 152 if !fieldV.CanSet() { 153 continue 154 } 155 156 fieldT := fieldInfo.Type 157 if fieldT.Kind() != reflect.Interface { 158 continue 159 } 160 161 service := fieldInfo.Tag.Get("service") 162 if service == "" { 163 continue 164 } 165 166 if !fieldV.IsNil() { 167 continue 168 } 169 170 builder, ok := _clientBuilders[fieldT] 171 if !ok { 172 panic(fmt.Sprintf("a constructor for %v has not been registered", fieldT)) 173 } 174 builderT := builder.Type() 175 176 args := make([]reflect.Value, 1, builderT.NumIn()) 177 args[0] = reflect.ValueOf(src.ClientConfig(service)) 178 if builderT.NumIn() > 1 { 179 args = append(args, reflect.ValueOf(fieldInfo)) 180 } 181 182 client := builder.Call(args)[0] 183 fieldV.Set(client) 184 } 185 }