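Introduction
TensorFlow Serving makes it easy to deploy offline-trained machine learning models online and exposes them to external callers over gRPC. It also supports hot model updates and automatic model version management, so algorithm engineers can focus on improving offline model quality rather than worrying about the online service.
This article shows how to build a linear regression prediction service with TensorFlow Serving. For a task this simple you could of course hard-code the two trained parameters w and b; the example is deliberately minimal and meant only as an introduction.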
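Environment setup
The environment used for this experiment:
- bazel 0.11.0: for building the model server used for deployment
- java 1.8: for the online service
- python 2.7: for training the model
- tensorflow 1.5.0: for training the model
- gRPC: for calling the service interface
- gcc 4.8.5: for compiling the code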
Training and saving the model
With the environment ready, we write a Python script, train.py, that trains the model and saves it.
Model code:
#!/usr/bin/env python
import numpy as np
import tensorflow as tf
import tensorflow.contrib.session_bundle.exporter as exporter

# Generate input data
n_samples = 1000
x_data = np.arange(100, step=.1)
y_data = x_data + 20 * np.sin(x_data / 10)
x_data = np.reshape(x_data, (n_samples, 1))
y_data = np.reshape(y_data, (n_samples, 1))

sample = 1000
learning_rate = 0.01
batch_size = 100
n_steps = 500

# Placeholders for batched input
x = tf.placeholder(tf.float32, shape=(batch_size, 1))
y = tf.placeholder(tf.float32, shape=(batch_size, 1))

# Model: y_pred = x * w + b
with tf.variable_scope('test'):
    w = tf.get_variable('weights', (1, 1), initializer=tf.random_normal_initializer())
    b = tf.get_variable('bias', (1,), initializer=tf.constant_initializer(0))
    y_pred = tf.matmul(x, w) + b
    loss = tf.reduce_sum((y - y_pred) ** 2 / n_samples)

opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    # Train on random mini-batches
    for _ in range(n_steps):
        indices = np.random.choice(n_samples, batch_size)
        x_batch = x_data[indices]
        y_batch = y_data[indices]
        _, loss_val = sess.run([opt, loss], feed_dict={x: x_batch, y: y_batch})
    print(w.eval())
    print(b.eval())
    print(loss_val)

    # Export the model; tf.constant("1") is the model version number
    saver = tf.train.Saver()
    model_exporter = exporter.Exporter(saver)
    model_exporter.init(
        sess.graph.as_graph_def(),
        named_graph_signatures={
            'inputs': exporter.generic_signature({'x': x}),
            'outputs': exporter.generic_signature({'y': y_pred})})
    model_exporter.export("/tmp/linear-regression/", tf.constant("1"), sess)
Run the script:
python train.py
Model output (w, b, and the final loss):
[[1.0108936]]
[1.9290849]
19.266457
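As a rough cross-check on the printed parameters, the same synthetic data can be fitted in closed form with NumPy. This is only a sketch: `np.polyfit` and the `mse` helper are not part of the original train.py, and the trained values are copied from the output above.

```python
import numpy as np

# Same synthetic data as train.py: y = x + 20 * sin(x / 10)
x = np.arange(0, 100, 0.1)
y = x + 20 * np.sin(x / 10)


def mse(w, b):
    """Mean squared error of the line y = w * x + b over the full data set."""
    return float(np.mean((y - (w * x + b)) ** 2))


# Closed-form least-squares line (np.polyfit returns slope first, then intercept).
w_ls, b_ls = np.polyfit(x, y, deg=1)

# Parameters printed by the training run above.
w_trained, b_trained = 1.0108936, 1.9290849

print("least squares: w=%.4f b=%.4f mse=%.4f" % (w_ls, b_ls, mse(w_ls, b_ls)))
print("trained:       w=%.4f b=%.4f mse=%.4f"
      % (w_trained, b_trained, mse(w_trained, b_trained)))
```

Note that the loss printed by train.py is the squared error summed over one batch of 100 samples and divided by n_samples (1000), so it is on a different scale from the full-data MSE computed here.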
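The /tmp/linear-regression/ directory then contains the following files (I trained twice and set the version number to 2 for the second run, which is why there are two folders):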
.
├── 00000001
│   ├── checkpoint
│   ├── export.data-00000-of-00001
│   ├── export.index
│   └── export.meta
└── 00000002
    ├── checkpoint
    ├── export.data-00000-of-00001
    ├── export.index
    └── export.meta
Deploying the model
# Download the tensorflow_serving source code
git clone https://github.com/tensorflow/serving.git
cd serving
# Build tensorflow_model_server
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
# Start the server
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=test --model_base_path=/tmp/linear-regression
Running ModelServer at 0.0.0.0:9000 ...
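Before sending requests, it can help to confirm that the server port is actually reachable. The sketch below only tests that a TCP connection can be opened; it does not speak gRPC. The helper name `port_is_open` and the host "localhost" are assumptions for illustration, while 9000 matches the --port flag above.

```python
import socket


def port_is_open(host, port, timeout=1.0):
    """Return True if a TCP connection to host:port can be opened."""
    try:
        conn = socket.create_connection((host, port), timeout=timeout)
    except (socket.error, socket.timeout):
        return False
    conn.close()
    return True


if __name__ == "__main__":
    # 9000 matches the --port flag used above; "localhost" is an assumption.
    print(port_is_open("localhost", 9000))
```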
Requesting the service
With the model deployed, we next write the Java code that sends requests to the online service:
ManagedChannel channel = null;
try {
    channel = ManagedChannelBuilder.forAddress("IP address of the serving host", 9000).usePlaintext(true).build();
    PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);

    // Select the model by the name passed to --model_name
    Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
    Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
    modelSpecBuilder.setName("test");
    predictRequestBuilder.setModelSpec(modelSpecBuilder);

    // Build a 1x1 float tensor holding the input x
    TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
    tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
    TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
    tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
    tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
    tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
    List<Float> floatList = new ArrayList<>();
    Random random = new Random();
    float x = random.nextFloat();
    floatList.add(x);
    tensorProtoBuilder.addAllFloatVal(floatList);
    predictRequestBuilder.putInputs("x", tensorProtoBuilder.build());

    // Send the request and log the output tensor "y"
    Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
    LOG.debug("x={}, result={}", x, predictResponse.getOutputsOrThrow("y").getFloatValList().toString());
} catch (StatusRuntimeException e) {
    LOG.error("StatusRuntimeException: ", e);
} finally {
    if (channel != null) {
        channel.shutdown();
    }
}
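The TensorProto built above describes a 1x1 float tensor: two dimensions of size 1 and a single entry in float_val. In NumPy terms this corresponds roughly to the sketch below; the input value 0.5 is a hypothetical stand-in for the random float drawn by the Java client.

```python
import numpy as np

x = 0.5  # hypothetical input value; the Java client draws a random float

# A 1x1 float32 tensor, matching DT_FLOAT and the two dims of size 1.
tensor = np.array([[x]], dtype=np.float32)

shape = list(tensor.shape)           # the two dims added to TensorShapeProto
float_val = tensor.ravel().tolist()  # values flattened in row-major order

print(shape, float_val)
```

A tensor of shape [n, 1] would carry n values in float_val, flattened in the same row-major order.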
Maven dependency:
<dependency>
    <groupId>com.yesup.oss</groupId>
    <artifactId>tensorflow-client</artifactId>
    <version>1.4-2</version>
</dependency>
Output after sending requests:
x=0.033170342, result=[1.9626166]
x=0.90723795, result=[2.846206]
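Since the model is just y = w * x + b, the served results can be checked by hand against the parameters printed during training. A minimal sketch, with w, b, and the request log values copied from the outputs above (the `predict` helper is only for illustration):

```python
# Parameters printed by train.py for version 1 of the model.
W = 1.0108936
B = 1.9290849


def predict(x):
    """Linear model y = w * x + b with the hard-coded trained parameters."""
    return W * x + B


# (input, served result) pairs copied from the log lines above.
served = [(0.033170342, 1.9626166), (0.90723795, 2.846206)]

for x, y_served in served:
    y_local = predict(x)
    print("x=%s local=%.6f served=%s diff=%.1e"
          % (x, y_local, y_served, abs(y_local - y_served)))
```

The local and served values should agree up to float32 rounding, which confirms that the server is running the exported parameters.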
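Updating the model
Change the version number in train.py, for example from 1 to 2, and run the script again.
Once it finishes, the served model is updated automatically.
The new parameters are:
[[1.0106797]]
[2.5606086]
20.278612
After sending a new request: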
x=0.8308283, result=[3.40031]
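The automatic update works because, by default, TensorFlow Serving watches the model base path and serves the highest-numbered version directory it finds. The sketch below is only a simplified illustration of that selection rule, not the server's actual implementation; `latest_version` is a hypothetical helper, and the demo uses a temporary directory laid out like /tmp/linear-regression/.

```python
import os
import tempfile


def latest_version(base_path):
    """Return the highest numeric version directory under base_path, or None."""
    versions = [
        int(name) for name in os.listdir(base_path)
        if name.isdigit() and os.path.isdir(os.path.join(base_path, name))
    ]
    return max(versions) if versions else None


# Demo on a temporary directory laid out like /tmp/linear-regression/.
base = tempfile.mkdtemp()
for version in (1, 2):
    os.mkdir(os.path.join(base, "%08d" % version))

print(latest_version(base))  # prints 2
```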
References
I consulted quite a few blog posts; only the two most useful ones are listed here.