
按照官方网站的指导,在项目模块的构建文件build.gradle中增加如下配置:
// Module-level build.gradle (e.g. app/build.gradle).
// The `implementation` declarations must live inside a `dependencies { }` block;
// written at the top level, Gradle fails with "Could not find method implementation()".
dependencies {
    // TensorFlow Lite runtime (provides org.tensorflow.lite.Interpreter).
    implementation 'org.tensorflow:tensorflow-lite:2.7.0'
    // TensorFlow Lite GPU delegate.
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.7.0'
    // TensorFlow Lite support and metadata libraries.
    implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
    implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
}

android {
    aaptOptions {
        // Keep .tflite assets uncompressed in the APK so they can be opened with
        // AssetManager.openFd() and memory-mapped; openFd() fails on compressed assets.
        // NOTE(review): newer Android Gradle Plugin versions may already exclude .tflite
        // from compression and deprecate aaptOptions - confirm for your AGP version.
        noCompress "tflite"
    }
    defaultConfig {
        ndk {
            // Package native libraries only for 32-bit and 64-bit ARM ABIs
            // (x86/x86_64 emulators will not get the TFLite native libraries).
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
    }
}
导入模型资源
将文章《关于将TensorFlow的SavedModel模型转换成tflite模型》中创建的模型model.tflite导入到Android项目的assets目录中。
定义模型基本配置类baseModelConfig
/**
 * Base holder for the preprocessing parameters of a TFLite image model.
 *
 * Concrete subclasses assign the fields in [setConfigs], which is invoked exactly once
 * while this base class is being constructed. They also decide how a single pixel is
 * encoded into the model's input buffer via [addImgValue].
 *
 * NOTE(review): because [setConfigs] runs during base-class construction, any property
 * declared in a subclass is not yet initialized when it executes. Overrides should only
 * assign the properties declared here.
 */
abstract class baseModelConfig {
    /** Bytes used to store one channel value in the model input buffer. */
    var numBytesPerChannel: Int = 0

    /** Number of images per inference batch. */
    var dimBatchSize: Int = 0

    /** Per-pixel size factor (channels per pixel) used when sizing the input buffer. */
    var dimPixelSize: Int = 0

    /** Model input width in pixels. */
    var dimImgWidth: Int = 0

    /** Model input height in pixels. */
    var dimImgHeight: Int = 0

    /** Mean value subtracted from each pixel during normalization. */
    var imageMean = 0

    /** Standard deviation each pixel value is divided by during normalization. */
    var imageSTD: Float = 0.0F

    /** Asset file name of the .tflite model. */
    lateinit var modelName: String

    init {
        // Placed after every property initializer, so the defaults above are applied
        // first and the subclass values assigned here are not overwritten.
        setConfigs()
    }

    /** Appends the model-ready encoding of one ARGB [pixel] to [buffer]. */
    abstract fun addImgValue(buffer: ByteBuffer, pixel: Int)

    /** Assigns this model's configuration values; called once from the constructor. */
    abstract fun setConfigs()
}
定义FloatSavedModelConfig类
/**
 * Preprocessing configuration for the float (non-quantized) `model.tflite`.
 *
 * The model takes a single 28x28, single-channel image. Each pixel is written as a
 * 4-byte Float equal to `(lowest byte of the pixel - imageMean) / imageSTD`. With
 * `imageMean = 0` and `imageSTD = 255`, this maps 0..255 onto 0.0..1.0.
 */
class FloatSavedModelConfig : baseModelConfig() {
    public override fun setConfigs() {
        modelName = "model.tflite"
        numBytesPerChannel = 4 // one Float per channel
        dimBatchSize = 1
        dimPixelSize = 1
        dimImgWidth = 28
        dimImgHeight = 28
        imageMean = 0
        imageSTD = 255.0f
    }

    /**
     * Appends one normalized pixel to [buffer] as a Float.
     *
     * Only the lowest byte of [pixel] (the blue channel of an ARGB color int) is used.
     * The parameter is named `buffer` to match the base declaration; the previous name
     * `imgdata` did not match the `imgData` referenced in the body, so it did not compile.
     */
    override fun addImgValue(buffer: ByteBuffer, pixel: Int) {
        buffer.putFloat(((pixel and 0xFF) - imageMean) / imageSTD)
    }
}
创建配置模型参数的工厂类
/**
 * Resolves a model-type key to a newly constructed [baseModelConfig].
 */
object ModelConfigFactory {
    const val FLOAT_SAVED_MODEL = "float_saved_model"
    const val QUANT_SAVED_MODEL = "quant_saved_model"

    // Key -> constructor; each lookup builds a fresh config instance.
    private val constructorsByKey: Map<String, () -> baseModelConfig> = mapOf(
        FLOAT_SAVED_MODEL to ::FloatSavedModelConfig,
        QUANT_SAVED_MODEL to ::QuantSavedModelConfig
    )

    /**
     * Returns a new configuration for [model], or `null` when the key is not one of
     * [FLOAT_SAVED_MODEL] / [QUANT_SAVED_MODEL].
     */
    fun getModelConfig(model: String): baseModelConfig? = constructorsByKey[model]?.invoke()
}
定义图像分类器
/**
 * Classifies a bitmap into one of ten Fashion-MNIST categories with a TFLite model.
 *
 * The constructor selects the model asset and input geometry from the [baseModelConfig]
 * returned by `ModelConfigFactory.getModelConfig(modelConfig)`, then memory-maps the
 * model from the app's assets. Call [close] when the classifier is no longer needed so
 * the native interpreter is released.
 */
class ImageClassifier {
    private val TAG = "FashionMNIST"

    // Number of highest-scoring labels included in the result text.
    private val RESULTS_TO_SHOW = 3

    lateinit var mTFLite: Interpreter
    lateinit var mModelPath: String
    var mNumBytesPerChannel = 0
    var mDimBatchSize = 0
    var mDimPixelSize = 0
    var mDimImgWidth = 0
    var mDimImgHeight = 0
    lateinit var mModelConfig: baseModelConfig

    // Model output: a 1x10 array (one batch row, one score per label).
    val mLabelProbArray = Array(1) { FloatArray(10) }

    val labels = arrayListOf("T恤", "裤子", "帽头衫", "连衣裙", "外套", "凉鞋", "衬衫", "运动鞋", "包", "靴子")

    // Min-heap ordered by score. The head is the lowest-scoring entry kept so far, so
    // polling it after each insert keeps only the RESULTS_TO_SHOW best labels.
    var mSortedLabels = PriorityQueue<Map.Entry<String, Float>>(RESULTS_TO_SHOW) { o1, o2 ->
        o1.value.compareTo(o2.value)
    }

    /** Copies the input geometry and model path out of [config]. */
    private fun initConfig(config: baseModelConfig) {
        mModelConfig = config
        mNumBytesPerChannel = config.numBytesPerChannel
        mDimBatchSize = config.dimBatchSize
        mDimPixelSize = config.dimPixelSize
        mDimImgWidth = config.dimImgWidth
        mDimImgHeight = config.dimImgHeight
        mModelPath = config.modelName
    }

    /**
     * @param modelConfig key understood by `ModelConfigFactory.getModelConfig`
     * @param activity used to open the model file from the app's assets
     * @throws IllegalArgumentException if [modelConfig] is not a known key
     */
    constructor(modelConfig: String, activity: Activity) {
        val config = requireNotNull(ModelConfigFactory.getModelConfig(modelConfig)) {
            "Unknown model config: $modelConfig"
        }
        initConfig(config)
        // Create the TFLite interpreter from the memory-mapped model.
        mTFLite = Interpreter(loadModelFile(activity))
    }

    /**
     * Memory-maps the model asset [mModelPath] read-only.
     *
     * The asset descriptor and stream are closed once mapping is done. A mapping stays
     * valid after the channel that created it is closed, so the returned buffer remains
     * usable.
     */
    private fun loadModelFile(activity: Activity): MappedByteBuffer =
        activity.assets.openFd(mModelPath).use { fileDescriptor ->
            FileInputStream(fileDescriptor.fileDescriptor).use { inputStream ->
                inputStream.channel.map(
                    FileChannel.MapMode.READ_ONLY,
                    fileDescriptor.startOffset,
                    fileDescriptor.declaredLength
                )
            }
        }

    /**
     * Scales [bitmap] to the model input size, binarizes it, and encodes every pixel
     * through [baseModelConfig.addImgValue] into a direct, native-order buffer.
     *
     * The returned buffer is rewound to position 0.
     */
    protected fun convertBitmapToByteBuffer(bitmap: Bitmap?): ByteBuffer {
        // Resize to the model's input size (28x28 for the float config), then reduce
        // the image to pure black/white.
        val prepared = binarized(scaleBitmap(bitmap))
        val intValues = IntArray(mDimImgWidth * mDimImgHeight)
        prepared.getPixels(intValues, 0, prepared.width, 0, 0, prepared.width, prepared.height)

        val imgData = ByteBuffer.allocateDirect(
            mNumBytesPerChannel * mDimBatchSize * mDimImgWidth * mDimImgHeight * mDimPixelSize
        )
        imgData.order(ByteOrder.nativeOrder())
        for (pixel in intValues) {
            mModelConfig.addImgValue(imgData, pixel)
        }
        // Reset the position so any consumer reading position..limit sees the full input.
        imgData.rewind()
        return imgData
    }

    /**
     * Returns a new ARGB_8888 bitmap the size of [bmp] where each pixel is set as follows:
     * - opaque white (0xFFFFFFFF) if the source pixel has alpha 255 and red, green and
     *   blue all greater than [threshold];
     * - opaque black (0xFF000000) otherwise.
     */
    fun binarized(bmp: Bitmap, threshold: Int = 180): Bitmap {
        val opaqueWhite = -1          // 0xFFFFFFFF
        val opaqueBlack = -0x1000000  // 0xFF000000
        val width = bmp.width
        val height = bmp.height
        val pixels = IntArray(width * height)
        // Load the source pixels into the array.
        bmp.getPixels(pixels, 0, width, 0, 0, width, height)
        for (index in pixels.indices) {
            val color = pixels[index]
            // Split the ARGB color into its channels.
            val alpha = color ushr 24
            val red = (color shr 16) and 0xFF
            val green = (color shr 8) and 0xFF
            val blue = color and 0xFF
            val isWhite = alpha == 0xFF && red > threshold && green > threshold && blue > threshold
            pixels[index] = if (isWhite) opaqueWhite else opaqueBlack
        }
        // Build the output bitmap from the binarized pixels.
        val newBmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
        newBmp.setPixels(pixels, 0, width, 0, 0, width, height)
        return newBmp
    }

    /**
     * Scales [bmp] to the model input size with filtering.
     *
     * @throws IllegalArgumentException if [bmp] is null
     */
    fun scaleBitmap(bmp: Bitmap?): Bitmap {
        val source = requireNotNull(bmp) { "bitmap to classify must not be null" }
        return Bitmap.createScaledBitmap(source, mDimImgWidth, mDimImgHeight, true)
    }

    /**
     * Runs the model on [bitmap] and returns the top [RESULTS_TO_SHOW] labels with
     * their scores (see [printTopKLabels]).
     */
    fun doClassify(bitmap: Bitmap?): String? {
        // Convert the bitmap into the ByteBuffer the TFLite interpreter reads.
        val imgData = convertBitmapToByteBuffer(bitmap)
        val startTime = System.nanoTime()
        mTFLite.run(imgData, mLabelProbArray)
        val endTime = System.nanoTime()
        Log.i(TAG, String.format(
            "运行识别的时间: %f ms",
            (endTime - startTime).toFloat() / 1000000.0f
        ))
        // Build and return the result text.
        return printTopKLabels()
    }

    /**
     * Formats the highest-scoring labels in [mLabelProbArray], best first.
     *
     * Each line is `"\n<label> <score>"`, so the text starts with a newline.
     */
    fun printTopKLabels(): String? {
        // Start from an empty heap even if a previous call was interrupted.
        mSortedLabels.clear()
        for (i in labels.indices) {
            mSortedLabels.add(java.util.AbstractMap.SimpleEntry(labels[i], mLabelProbArray[0][i]))
            if (mSortedLabels.size > RESULTS_TO_SHOW) {
                mSortedLabels.poll()
            }
        }
        val textToShow = StringBuilder()
        while (mSortedLabels.isNotEmpty()) {
            val label = mSortedLabels.poll()
            // The heap yields ascending scores; prepending produces descending order.
            // The leading "\n" was lost as a bare "n" in the previous version.
            textToShow.insert(0, String.format("\n%s %4.8f", label.key, label.value))
        }
        return textToShow.toString()
    }

    /** Releases the native resources held by the TFLite interpreter. */
    fun close() {
        mTFLite.close()
    }
}
定义主活动MainActivity
在主活动中,主要处理如下操作:
(1)从图库中选择图片
(2)利用图像分类器检测图片中的内容,判断是FashionMnist数据集的哪种标签
(3)将检测的结果在移动终端的GUI界面中显示出来。
// Main activity: hosts the UI, lets the user pick an image, and shows the classification.
// NOTE(review): this copy of the class is corrupted. The string literal on the
// `intent.type = ...` line is unterminated, and `getChoices()` appears inside `onCreate`.
// Code between those lines appears to be missing. Recover the original source before
// building; this block will not compile as written.
class MainActivity : AppCompatActivity() {
// View binding for activity_main, initialized in onCreate.
private lateinit var binding: ActivityMainBinding
// NOTE(review): not used in the visible code; presumably an activity request code - confirm.
val RequestCameraCode = 1
val TAG = "FashionMNIST"
companion object{
// NOTE(review): not read in the visible code; presumably selects the float vs. quantized
// model config - confirm against the missing part of the class.
var mIsFloat = true
}
// Currently selected image, if any.
private var bitmap: Bitmap? = null
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
// Create the view-binding object
binding = ActivityMainBinding.inflate(layoutInflater)
// Use the binding's root view as the activity's content view
setContentView(binding.root)
// NOTE(review): `setonClickListener` looks like a typo for Android's
// `View.setOnClickListener`; Kotlin is case-sensitive, so this will not resolve.
// Tapping the image view appears intended to launch an Intent for choosing an image.
binding.imageView.setonClickListener {
val intent = Intent()
intent.type = "image
private fun getChoices()= resources.getStringArray(R.array.model_names)
}
参考文献
李锡涵等,《简明的TensorFlow 2》,人民邮电出版社,北京,P91-P96
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)