Background
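This article is a reference example for the TensorFlow Serving (TFS) Java API and covers the usage of its core APIs. The examples fall into three parts:
- Dynamic model updates: load models while TFS is running.
- Model status: retrieve basic information about a loaded model.
- Online prediction: run online prediction, classification and similar operations; the focus here is on prediction.
A prediction request has to match the model's internals (signature name, input/output tensor keys, dtypes and shapes), so first fetch the TF model's metadata through the TFS REST API, and only then build the TFS gRPC request object.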
Getting Started with TFS
Fetching Model Metadata
curl http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]/metadata
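Notes:
- See the TFS REST API documentation (model metadata API).
- For a sample response, see the wdl_model metadata at the end of this article.
The same metadata can also be fetched over gRPC from Java: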
public static void getModelStatus() {
    // 1. Set the host and port of the TFS gRPC endpoint (plaintext, no TLS);
    //    host and port are fields of the enclosing class (TFS gRPC default port: 8500)
    ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
    try {
        // 2. Build the PredictionServiceBlockingStub
        PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub =
                PredictionServiceGrpc.newBlockingStub(channel);
        // 3. Specify the model to query
        Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
                .setName("wdl_model").build();
        // 4. Build the metadata request, asking for the signature_def field
        GetModelMetadata.GetModelMetadataRequest modelMetadataRequest =
                GetModelMetadata.GetModelMetadataRequest.newBuilder()
                        .setModelSpec(modelSpec)
                        .addAllMetadataField(Arrays.asList("signature_def"))
                        .build();
        // 5. Fetch the metadata
        GetModelMetadata.GetModelMetadataResponse getModelMetadataResponse =
                predictionServiceBlockingStub.getModelMetadata(modelMetadataRequest);
        System.out.println(getModelMetadataResponse);
    } finally {
        // Release the channel even if the call fails
        channel.shutdownNow();
    }
}
Notes:
- Model.ModelSpec.newBuilder() sets the name of the model to query.
- addAllMetadataField in GetModelMetadataRequest selects the signature_def field, i.e. the same field that appears under metadata in the curl output.
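For comparison, the curl call shown earlier can also be made from plain Java without the generated gRPC stubs. The following is a minimal sketch assuming JDK 11+ (`java.net.http.HttpClient`); the class and helper names `TfsMetadataUrl`, `metadataUrl` and `fetchMetadata`, the host `localhost` and the port 8501 (the REST port commonly exposed by the official TFS Docker image) are illustrative assumptions, not part of the original code.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Sketch: call the TFS REST metadata endpoint from plain Java (JDK 11+).
final class TfsMetadataUrl {

    // Builds http://host:port/v1/models/{name}[/versions/{version}]/metadata.
    // A null version omits the /versions/ segment and lets TFS pick the version.
    static String metadataUrl(String host, int port, String modelName, Long version) {
        StringBuilder url = new StringBuilder("http://")
                .append(host).append(':').append(port)
                .append("/v1/models/").append(modelName);
        if (version != null) {
            url.append("/versions/").append(version);
        }
        return url.append("/metadata").toString();
    }

    // Plain HTTP GET; returns the metadata JSON or throws on a non-200 status.
    static String fetchMetadata(String url) throws IOException, InterruptedException {
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request = HttpRequest.newBuilder(URI.create(url)).GET().build();
        HttpResponse<String> response =
                client.send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("HTTP " + response.statusCode() + ": " + response.body());
        }
        return response.body();
    }

    public static void main(String[] args) throws Exception {
        // Assumption: TFS REST API reachable on localhost:8501; adjust to your --rest_api_port.
        String url = metadataUrl("localhost", 8501, "wdl_model", 4L);
        System.out.println(url);
        // Only touch the network when asked, e.g. `java TfsMetadataUrl fetch`.
        if (args.length > 0 && args[0].equals("fetch")) {
            System.out.println(fetchMetadata(url));
        }
    }
}
```

Run without arguments, it only prints http://localhost:8501/v1/models/wdl_model/versions/4/metadata, so the URL can be checked before any server is involved.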
Dynamically Updating Models
public static void addNewModel() {
    // 1. Config for the first model to load
    ModelServerConfigOuterClass.ModelConfig modelConfig1 =
            ModelServerConfigOuterClass.ModelConfig.newBuilder()
                    .setBasePath("/models/new_model")
                    .setName("new_model")
                    // model_type is deprecated in model_server_config.proto; model_platform replaces it
                    .setModelPlatform("tensorflow")
                    .build();
    // 2. Config for the second model
    ModelServerConfigOuterClass.ModelConfig modelConfig2 =
            ModelServerConfigOuterClass.ModelConfig.newBuilder()
                    .setBasePath("/models/wdl_model")
                    .setName("wdl_model")
                    .setModelPlatform("tensorflow")
                    .build();
    // 3. Collect both configs in a ModelConfigList: this is the complete list the server will serve
    ModelServerConfigOuterClass.ModelConfigList modelConfigList =
            ModelServerConfigOuterClass.ModelConfigList.newBuilder()
                    .addConfig(modelConfig1)
                    .addConfig(modelConfig2)
                    .build();
    // 4. Wrap the ModelConfigList in a ModelServerConfig
    ModelServerConfigOuterClass.ModelServerConfig modelServerConfig =
            ModelServerConfigOuterClass.ModelServerConfig.newBuilder()
                    .setModelConfigList(modelConfigList)
                    .build();
    // 5. Build a ReloadConfigRequest carrying the ModelServerConfig
    ModelManagement.ReloadConfigRequest reloadConfigRequest =
            ModelManagement.ReloadConfigRequest.newBuilder()
                    .setConfig(modelServerConfig)
                    .build();
    // 6. Build the ModelServiceBlockingStub and send the request
    ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
    try {
        ModelServiceGrpc.ModelServiceBlockingStub modelServiceBlockingStub =
                ModelServiceGrpc.newBlockingStub(channel);
        ModelManagement.ReloadConfigResponse reloadConfigResponse =
                modelServiceBlockingStub.handleReloadConfigRequest(reloadConfigRequest);
        // The error message is empty when the reload succeeds
        System.out.println(reloadConfigResponse.getStatus().getErrorMessage());
    } finally {
        channel.shutdownNow();
    }
}
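Notes:
- A dynamic update is a full reload, not an incremental one: if model A is already deployed and you want to add model B at runtime, the request must contain the configs of both A and B.
- To stress it again: the update replaces the whole model list. Any model left out of the request is unloaded by the server.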
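Because every reload replaces the full list, one way to avoid unloading a model by accident is to keep the complete set of models on the client side and always send all of it. The following is a minimal sketch under that assumption; `ModelConfigRegistry` and its methods are hypothetical, and in real code each entry would become one `ModelServerConfigOuterClass.ModelConfig` passed to `addConfig`. It also renders the same list in the `model_config_list` text format read by `tensorflow_model_server --model_config_file`, which can help keep the startup config file in sync with runtime reloads.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: remembers every model that should stay online, because each
// ReloadConfigRequest replaces the server's entire model list.
final class ModelConfigRegistry {
    // Model name -> base path, kept in insertion order.
    private final Map<String, String> basePaths = new LinkedHashMap<>();

    ModelConfigRegistry put(String name, String basePath) {
        basePaths.put(name, basePath);
        return this;
    }

    // A removed model will be unloaded by the next full reload.
    ModelConfigRegistry remove(String name) {
        basePaths.remove(name);
        return this;
    }

    // Every name here must appear in the next ReloadConfigRequest (one ModelConfig each).
    List<String> names() {
        return new ArrayList<>(basePaths.keySet());
    }

    // Render the same full list in the text format accepted by --model_config_file.
    String toTextProto() {
        StringBuilder sb = new StringBuilder("model_config_list {\n");
        for (Map.Entry<String, String> e : basePaths.entrySet()) {
            sb.append("  config {\n")
              .append("    name: '").append(e.getKey()).append("'\n")
              .append("    base_path: '").append(e.getValue()).append("'\n")
              .append("    model_platform: 'tensorflow'\n")
              .append("  }\n");
        }
        return sb.append("}\n").toString();
    }

    public static void main(String[] args) {
        ModelConfigRegistry registry = new ModelConfigRegistry()
                .put("wdl_model", "/models/wdl_model");
        // Publishing new_model later: the request must still include wdl_model.
        registry.put("new_model", "/models/new_model");
        System.out.println(registry.names());
        System.out.print(registry.toTextProto());
    }
}
```

Running main prints [wdl_model, new_model] followed by a two-entry model_config_list, i.e. exactly the A-plus-B payload the notes above call for.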
Online Model Prediction
public static List<Float> doPredict() throws Exception {
    // 1. Build the features of one tf.Example (names and types must match the model's feature columns)
    Map<String, Feature> featureMap = new HashMap<>();
    featureMap.put("match_type", feature(""));
    featureMap.put("position", feature(0.0f));
    featureMap.put("brand_prefer_1d", feature(0.0f));
    featureMap.put("brand_prefer_1m", feature(0.0f));
    featureMap.put("brand_prefer_1w", feature(0.0f));
    featureMap.put("brand_prefer_2w", feature(0.0f));
    featureMap.put("browse_norm_score_1d", feature(0.0f));
    featureMap.put("browse_norm_score_1w", feature(0.0f));
    featureMap.put("browse_norm_score_2w", feature(0.0f));
    featureMap.put("buy_norm_score_1d", feature(0.0f));
    featureMap.put("buy_norm_score_1w", feature(0.0f));
    featureMap.put("buy_norm_score_2w", feature(0.0f));
    featureMap.put("cate1_prefer_1d", feature(0.0f));
    featureMap.put("cate1_prefer_2d", feature(0.0f));
    featureMap.put("cate1_prefer_1m", feature(0.0f));
    featureMap.put("cate1_prefer_1w", feature(0.0f));
    featureMap.put("cate1_prefer_2w", feature(0.0f));
    featureMap.put("cate2_prefer_1d", feature(0.0f));
    featureMap.put("cate2_prefer_1m", feature(0.0f));
    featureMap.put("cate2_prefer_1w", feature(0.0f));
    featureMap.put("cate2_prefer_2w", feature(0.0f));
    featureMap.put("cid_prefer_1d", feature(0.0f));
    featureMap.put("cid_prefer_1m", feature(0.0f));
    featureMap.put("cid_prefer_1w", feature(0.0f));
    featureMap.put("cid_prefer_2w", feature(0.0f));
    featureMap.put("user_buy_rate_1d", feature(0.0f));
    featureMap.put("user_buy_rate_2w", feature(0.0f));
    featureMap.put("user_click_rate_1d", feature(0.0f));
    featureMap.put("user_click_rate_1w", feature(0.0f));
    Features features = Features.newBuilder().putAllFeature(featureMap).build();
    Example example = Example.newBuilder().setFeatures(features).build();
    // 2. Build the Predict request
    Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
    // 3. Build the ModelSpec: model name and the signature used for prediction
    Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
    modelSpecBuilder.setName("wdl_model");
    modelSpecBuilder.setSignatureName("predict");
    predictRequestBuilder.setModelSpec(modelSpecBuilder);
    // 4. Describe the input tensor: a 1-D DT_STRING tensor holding batchSize serialized Examples
    int batchSize = 300;
    TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(batchSize).build();
    TensorShapeProto shapeProto = TensorShapeProto.newBuilder().addDim(dim).build();
    TensorProto.Builder tensor = TensorProto.newBuilder();
    tensor.setTensorShape(shapeProto);
    tensor.setDtype(DataType.DT_STRING);
    // 5. Fill the batch (here the same Example repeated); the count should match the dim size
    for (int i = 0; i < batchSize; i++) {
        tensor.addStringVal(example.toByteString());
    }
    // "examples" is the input key of the predict signature
    predictRequestBuilder.putInputs("examples", tensor.build());
    // 6. Build the PredictionServiceBlockingStub
    ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
    try {
        PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub =
                PredictionServiceGrpc.newBlockingStub(channel);
        // 7. Run the prediction
        Predict.PredictResponse predictResponse =
                predictionServiceBlockingStub.predict(predictRequestBuilder.build());
        // 8. Parse the result: "probabilities" has shape [batchSize, 2], returned as a flat list
        return predictResponse
                .getOutputsOrThrow("probabilities")
                .getFloatValList();
    } finally {
        channel.shutdownNow();
    }
}

// Helpers for the feature(...) calls above: strings go into a BytesList, floats into a FloatList
private static Feature feature(String value) {
    return Feature.newBuilder()
            .setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8(value)))
            .build();
}

private static Feature feature(float value) {
    return Feature.newBuilder()
            .setFloatList(FloatList.newBuilder().addValue(value))
            .build();
}
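Notes:
- The parameters set in the TFS gRPC request must match the TF model's signature: the model name, the signature name (predict), the input key (examples, a DT_STRING tensor of serialized tf.Example protos) and the output key (probabilities). All of these come from the model metadata.
- TFS gRPC calls can be made synchronously or asynchronously; only the synchronous (blocking stub) variant is shown above. The generated PredictionServiceGrpc class also provides newFutureStub and newStub for the asynchronous variants.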
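In the prediction request, the dim size set in step 4 and the number of Examples added in step 5 should agree (300 in the example). The predict signature declares its input shape as [-1], so the batch size itself is free; when the number of candidates varies, one option is to split them into chunks and size each request from its chunk. A minimal sketch; `PredictBatches.partition` is a hypothetical helper, and 300 is simply the value used above, not a TFS limit.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Sketch: split an arbitrary number of candidates into request-sized batches.
final class PredictBatches {

    // Hypothetical helper: each returned chunk becomes one PredictRequest whose
    // TensorShapeProto dim size is chunk.size().
    static <T> List<List<T>> partition(List<T> items, int maxBatchSize) {
        if (maxBatchSize <= 0) {
            throw new IllegalArgumentException("maxBatchSize must be positive");
        }
        List<List<T>> batches = new ArrayList<>();
        for (int from = 0; from < items.size(); from += maxBatchSize) {
            int to = Math.min(from + maxBatchSize, items.size());
            batches.add(Collections.unmodifiableList(new ArrayList<>(items.subList(from, to))));
        }
        return batches;
    }

    public static void main(String[] args) {
        List<Integer> candidateIds = new ArrayList<>();
        for (int i = 0; i < 700; i++) {
            candidateIds.add(i);
        }
        for (List<Integer> batch : partition(candidateIds, 300)) {
            // In the real request: setSize(batch.size()) and one addStringVal per element.
            System.out.println("dim size = " + batch.size());
        }
    }
}
```

For 700 candidates, main prints dim sizes 300, 300 and 100, so each request's declared shape matches the data it carries.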
TF Model Structure (sample metadata response for wdl_model)
{
"model_spec": {
"name": "wdl_model",
"signature_name": "",
"version": "4"
},
"metadata": {
"signature_def": {
"signature_def": {
"predict": {
"inputs": {
"examples": {
"dtype": "DT_STRING",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
}],
"unknown_rank": false
},
"name": "input_example_tensor:0"
}
},
"outputs": {
"logistic": {
"dtype": "DT_FLOAT",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "1",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/predictions/logistic:0"
},
"class_ids": {
"dtype": "DT_INT64",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "1",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/predictions/ExpandDims:0"
},
"probabilities": {
"dtype": "DT_FLOAT",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "2",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/predictions/probabilities:0"
},
"classes": {
"dtype": "DT_STRING",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "1",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/predictions/str_classes:0"
},
"logits": {
"dtype": "DT_FLOAT",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "1",
"name": ""
}
],
"unknown_rank": false
},
"name": "add:0"
}
},
"method_name": "tensorflow/serving/predict"
},
"classification": {
"inputs": {
"inputs": {
"dtype": "DT_STRING",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
}],
"unknown_rank": false
},
"name": "input_example_tensor:0"
}
},
"outputs": {
"classes": {
"dtype": "DT_STRING",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "2",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/Tile:0"
},
"scores": {
"dtype": "DT_FLOAT",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "2",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/predictions/probabilities:0"
}
},
"method_name": "tensorflow/serving/classify"
},
"regression": {
"inputs": {
"inputs": {
"dtype": "DT_STRING",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
}],
"unknown_rank": false
},
"name": "input_example_tensor:0"
}
},
"outputs": {
"outputs": {
"dtype": "DT_FLOAT",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "1",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/predictions/logistic:0"
}
},
"method_name": "tensorflow/serving/regress"
},
"serving_default": {
"inputs": {
"inputs": {
"dtype": "DT_STRING",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
}],
"unknown_rank": false
},
"name": "input_example_tensor:0"
}
},
"outputs": {
"classes": {
"dtype": "DT_STRING",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "2",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/Tile:0"
},
"scores": {
"dtype": "DT_FLOAT",
"tensor_shape": {
"dim": [{
"size": "-1",
"name": ""
},
{
"size": "2",
"name": ""
}
],
"unknown_rank": false
},
"name": "head/predictions/probabilities:0"
}
},
"method_name": "tensorflow/serving/classify"
}
}
}
}
}
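In the metadata above, the predict signature's `probabilities` output has shape [-1, 2], one row per input Example. In the prediction code, `getFloatValList()` returns these values as one flat list in row-major order, so row i occupies indices 2*i and 2*i+1. A minimal sketch of splitting it back into per-example scores; `ProbabilityDecoder.column` is a hypothetical helper, and reading column 1 as the positive class (class id 1) is an assumption about this binary wide-and-deep head that should be checked against the classification signature's `classes` output.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: read a [batch, numClasses] output that arrives flattened in row-major order.
final class ProbabilityDecoder {

    // Returns probabilities[row][classIndex] for every row.
    static List<Float> column(List<Float> flat, int numClasses, int classIndex) {
        if (numClasses <= 0 || flat.size() % numClasses != 0) {
            throw new IllegalArgumentException(
                    "expected a multiple of " + numClasses + " values, got " + flat.size());
        }
        if (classIndex < 0 || classIndex >= numClasses) {
            throw new IllegalArgumentException("classIndex out of range: " + classIndex);
        }
        int rows = flat.size() / numClasses;
        List<Float> result = new ArrayList<>(rows);
        for (int row = 0; row < rows; row++) {
            result.add(flat.get(row * numClasses + classIndex));
        }
        return result;
    }

    public static void main(String[] args) {
        // Two examples, laid out as getFloatValList() would return them: [p0, p1, p0, p1].
        List<Float> flat = List.of(0.9f, 0.1f, 0.25f, 0.75f);
        System.out.println(column(flat, 2, 1)); // [0.1, 0.75]
    }
}
```

Validating that the list length is a multiple of the column count catches a mismatch between the requested batch size and the returned tensor early, instead of silently misaligning scores.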