github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp (about) 1 //===- UniformSupport.cpp - Support utilities for uniform quant -----------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/QuantOps/UniformSupport.h" 19 #include "mlir/IR/StandardTypes.h" 20 21 using namespace mlir; 22 using namespace mlir::quant; 23 24 static bool isQuantizablePrimitiveType(Type inputType) { 25 return inputType.isa<FloatType>(); 26 } 27 28 const ExpressedToUniformQuantizedConverter 29 ExpressedToUniformQuantizedConverter::forInputType(Type inputType) { 30 switch (inputType.getKind()) { 31 default: 32 if (isQuantizablePrimitiveType(inputType)) { 33 // Supported primitive type (which just is the expressed type). 34 return ExpressedToUniformQuantizedConverter{inputType, inputType}; 35 } 36 // Unsupported. 37 return ExpressedToUniformQuantizedConverter{inputType, nullptr}; 38 case StandardTypes::RankedTensor: 39 case StandardTypes::UnrankedTensor: 40 case StandardTypes::Vector: { 41 Type elementType = inputType.cast<ShapedType>().getElementType(); 42 if (!isQuantizablePrimitiveType(elementType)) { 43 // Unsupported. 44 return ExpressedToUniformQuantizedConverter{inputType, nullptr}; 45 } 46 return ExpressedToUniformQuantizedConverter{ 47 inputType, inputType.cast<ShapedType>().getElementType()}; 48 } 49 } 50 } 51 52 Type ExpressedToUniformQuantizedConverter::convert( 53 UniformQuantizedType elementalType) const { 54 assert(expressedType && "convert() on unsupported conversion"); 55 56 switch (inputType.getKind()) { 57 default: 58 if (isQuantizablePrimitiveType(elementalType)) { 59 // For primitives, just use the new elemental type. 60 return elementalType; 61 } 62 // Unsupported. 63 return nullptr; 64 case StandardTypes::RankedTensor: 65 return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(), 66 elementalType); 67 case StandardTypes::UnrankedTensor: 68 return UnrankedTensorType::get(elementalType); 69 case StandardTypes::Vector: 70 return VectorType::get(inputType.cast<VectorType>().getShape(), 71 elementalType); 72 } 73 }