//===- LowerUniformRealMath.cpp ------------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "UniformKernelUtils.h"

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;

namespace {

struct LowerUniformRealMathPass
    : public FunctionPass<LowerUniformRealMathPass> {
  void runOnFunction() override;
};

struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
  void runOnFunction() override;
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//

static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
                                            UniformQuantizedType elementType,
                                            PatternRewriter &rewriter) {
  // Pre-conditions.
  if (!elementType.isSigned()) {
    // TODO: Support unsigned storage type.
    emitWarning(loc, "unimplemented: dequantize unsigned uniform");
    return nullptr;
  }

  Type storageType = elementType.castToStorageType(input->getType());
  Type realType = elementType.castToExpressedType(input->getType());
  Type intermediateType =
      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
  assert(storageType && "cannot cast to storage type");
  assert(realType && "cannot cast to expressed type");

  // Cast to storage type.
  input = rewriter.create<StorageCastOp>(loc, storageType, input);

  // Promote to intermediate type.
  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);

  // Apply zero-point offset.
  if (elementType.getZeroPoint() != 0) {
    Value *negZeroPointConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(intermediateType,
                                          -elementType.getZeroPoint()));
    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
  }

  // Convert to float.
  input = rewriter.create<ConvertISToFOp>(loc, realType, input);

  // Mul by scale.
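  // Math note (added commentary): uniform dequantization reconstructs the
  // real value as real = scale * (storage - zeroPoint). The zero-point
  // offset was subtracted above; the final step below applies the scale as
  // a floating-point multiply.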
  Value *scaleConst = rewriter.create<ConstantOp>(
      loc, broadcastScalarConstFloatValue(realType,
                                          APFloat(elementType.getScale())));
  return rewriter.create<MulFOp>(loc, input, scaleConst);
}

static Value *
emitUniformPerAxisDequantize(Location loc, Value *input,
                             UniformQuantizedPerAxisType elementType,
                             PatternRewriter &rewriter) {
  // TODO: Support per-axis dequantize.
  rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
      << "unimplemented: per-axis uniform dequantization";
  return nullptr;
}

static Value *emitDequantize(Location loc, Value *input,
                             PatternRewriter &rewriter) {
  Type inputType = input->getType();
  QuantizedType qElementType =
      QuantizedType::getQuantizedElementType(inputType);
  if (auto perLayerElementType =
          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
    return emitUniformPerLayerDequantize(loc, input, perLayerElementType,
                                         rewriter);
  } else if (auto perAxisElementType =
                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
    return emitUniformPerAxisDequantize(loc, input, perAxisElementType,
                                        rewriter);
  } else {
    return nullptr;
  }
}

namespace {

struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
  using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(DequantizeCastOp op,
                                     PatternRewriter &rewriter) const override {
    Type inputType = op.arg()->getType();
    Type outputType = op.getResult()->getType();

    QuantizedType inputElementType =
        QuantizedType::getQuantizedElementType(inputType);
    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
    if (expressedOutputType != outputType) {
      // Not a valid uniform cast.
      return matchFailure();
    }

    Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
    if (!dequantizedValue) {
      return matchFailure();
    }

    rewriter.replaceOp(op, dequantizedValue);
    return matchSuccess();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//

static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
                                      PatternRewriter &rewriter) {
  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
      info.rhsType != info.resultType) {
    return failure();
  }

  // Choose a byte aligned intermediate width big enough to perform the
  // calculation without overflow.
  // TODO: This should probably be made just big enough to avoid overflow and
  // leave the downstream tooling to decide how to align that to machine
  // word sizes.
  unsigned intermediateWidth =
      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
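  // Math note (added commentary): both operands and the result share scale S
  // and zero point zp here, so the real-valued sum
  //   S*(lhs - zp) + S*(rhs - zp) = S*((lhs + rhs - zp) - zp)
  // means the quantized result is simply lhs + rhs - zp, computed below in
  // the wider intermediate type and then clamped back to the storage range.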
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Add.
  Value *resultValue =
      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Zero point offset adjustment.
  // result = (lhs - zp) + (rhs - zp) + zp
  // zpOffset = -zp
  int zpOffset = -1 * info.resultType.getZeroPoint();
  if (zpOffset != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType, zpOffset));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}

//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//

static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
                            PatternRewriter &rewriter) {
  if (!info.resultType.isSigned()) {
    return failure();
  }

  double outputMultiplierReal = info.lhsType.getScale() *
                                info.rhsType.getScale() /
                                info.resultType.getScale();
  if (outputMultiplierReal > 1.0) {
    info.op->emitWarning(
        "unimplemented: cannot multiply with multiplier > 1.0");
    return failure();
  }

  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
  // avoid overflow.
  unsigned intermediateWidth = 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Apply argument zeroPoints.
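  // Math note (added commentary): with real values x = S * (q - zp), the
  // product of the operands is
  //   S_lhs * S_rhs * (lhs - zp_lhs) * (rhs - zp_rhs)
  // and requantizing into the result type gives
  //   result = (S_lhs * S_rhs / S_out) * (lhs - zp_lhs) * (rhs - zp_rhs) + zp_out
  // outputMultiplierReal (computed above) is that scale ratio; it is applied
  // after the integer multiply as a fixed-point multiply plus a rounding
  // right shift.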
  if (info.lhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.lhsType.getZeroPoint()));
    lhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
  }

  if (info.rhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.rhsType.getZeroPoint()));
    rhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
  }

  // Mul.
  Value *resultValue =
      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Scale output.
  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
  resultValue = rewriter.create<RoundingDivideByPotISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));

  // Zero point offset adjustment.
  if (info.resultType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType,
                                     info.resultType.getZeroPoint()));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}

namespace {

struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
  using OpRewritePattern<RealAddEwOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(RealAddEwOp op,
                                     PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());
    if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
      return matchSuccess();
    }

    return matchFailure();
  }
};

struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
  using OpRewritePattern<RealMulEwOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(RealMulEwOp op,
                                     PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());
    if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
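    // As with add, only the signed permutation is currently implemented.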
    if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
      return matchSuccess();
    }

    return matchFailure();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//

void LowerUniformRealMathPass::runOnFunction() {
  auto fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
  applyPatternsGreedily(fn, patterns);
}

FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
  return new LowerUniformRealMathPass();
}

static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
    "fxpmath-lower-uniform-real-math",
    "Lowers uniform-quantized real math ops to integer arithmetic.");

//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//

void LowerUniformCastsPass::runOnFunction() {
  auto fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.insert<UniformDequantizePattern>(context);
  applyPatternsGreedily(fn, patterns);
}

FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
  return new LowerUniformCastsPass();
}

static PassRegistration<LowerUniformCastsPass>
    lowerUniformCastsPass("fxpmath-lower-uniform-casts",
                          "Lowers uniform-quantized casts.");
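// Example invocation (a sketch, not part of this file: the flag names come
// from the PassRegistration entries above; `input.mlir` is a placeholder):
//   mlir-opt -fxpmath-lower-uniform-real-math -fxpmath-lower-uniform-casts input.mlir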