
问题解决背景
需要在spark2.xx上面调用TensorFlow2.x 训练的模型在本地环境运行 ; 生产在搭建的集群或者已集成的三方服务中运行。比如本文中介绍的使用基于阿里云的emr 、oss
运行环境:开发 windows; 生产 linux
1. tf模型准备
protobuf模型需保存为pb格式
model_path = "path/model"
model.save(model_path, save_format="tf")
2. spark 项目,pom添加相关依赖
shade打包添加 org.tensorflow tensorflow1.15.0 org.tensorflow*:* com.google.protobuf:*
3. 覆盖原有方法。具体见附录
4. 将model添加到spark file中
// 如果是本地运行,则不需要addFile, 直接load模型使用全路径即可 modelPath = "oss://xxxx//model_name
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)