Unverified commit a418dd44 authored by chenjian, committed by GitHub

Add fastdeploy server and client component (#1169)

* add backend support for fastdeploy server

* fix

* add code

* fix

* fix

* add fastdeploy server component

* add fastdeploy server and client

* add exception description

* fix

* add model repository judgement

* add component tab for fastdeploy client

* update more tasks in fastdeploy client

* sort filenames

* backup config

* noqa for autogenerated file

* add data validation

* add __init__ for package

* add calculating layout for frontend

* add alive server detection and optimize client

* add alive server detection and optimize client

* add alive server detection and optimize client

* add metrics in gradio client

* update presentation

* Change return value to None for frontend performance data when server not ready

* add get_server_config and download_pretrain_model api

* add get_server_config and download_pretrain_model api

* add unit for metric table

* add unit for metric table

* fix a bug

* add judgement pretrained model download

* add judgement pretrained model download

* add version info for frontend

* rename download model

* fix a bug

* add fastdeploy model list

* optimize for choose configuration files

* modify according to frontend need

* fix name in config to model name

* optimize for server list and alive judgement

* keep server name as string type

* optimize process judgement logic

* optimize for deleting resource files

* add rename resource file

* fix

* fix a bug

* optimize code structure

* optimize code structure

* remove chinese tips and remove fastdeploy-python in requirements
Parent b90619b9
@@ -12,4 +12,8 @@ multiprocess
packaging
x2paddle
rarfile
onnx >= 1.6.0
\ No newline at end of file
gradio
tritonclient[all]
attrdict
psutil
onnx >= 1.6.0
# Copyright (c) 2022 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
# Copyright (c) 2022 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import gradio as gr
import numpy as np
from .http_client_manager import get_metric_data
from .http_client_manager import HttpClientManager
from .http_client_manager import metrics_table_head
from .visualizer import visualize_detection
from .visualizer import visualize_face_alignment
from .visualizer import visualize_face_detection
from .visualizer import visualize_headpose
from .visualizer import visualize_keypoint_detection
from .visualizer import visualize_matting
from .visualizer import visualize_ocr
from .visualizer import visualize_segmentation
_http_manager = HttpClientManager()
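# Map each task type to its visualization helper; 'unspecified' falls back to showing the raw output as text.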
supported_tasks = {
'detection': visualize_detection,
'facedet': visualize_face_detection,
'keypointdetection': visualize_keypoint_detection,
'segmentation': visualize_segmentation,
'matting': visualize_matting,
'ocr': visualize_ocr,
'facealignment': visualize_face_alignment,
'headpose': visualize_headpose,
'unspecified': lambda x: str(x)
}
def create_gradio_client_app(): # noqa:C901
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-color: black;
background: black;
}
input[type='range'] {
accent-color: black;
}
.dark input[type='range'] {
accent-color: #dfdfdf;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;
}
.details:hover {
text-decoration: underline;
}
.gr-button {
white-space: nowrap;
}
.gr-button:focus {
border-color: rgb(147 197 253 / var(--tw-border-opacity));
outline: none;
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
--tw-border-opacity: 1;
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) \
var(--tw-ring-offset-color);
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
--tw-ring-opacity: .5;
}
.footer {
margin-bottom: 45px;
margin-top: 35px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
.prompt h4{
margin: 1.25em 0 .25em 0;
font-weight: bold;
font-size: 115%;
}
"""
block = gr.Blocks(css=css)
with block:
gr.HTML("""
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div
style="
display: inline-flex;
gap: 0.8rem;
font-size: 1.75rem;
justify-content: center;
"
>
<h1>
FastDeploy Client
</h1>
</div>
<p style="font-size: 94%">
The client is used to create requests to the fastdeploy server.
</p>
</div>
""")
with gr.Group():
with gr.Box():
with gr.Column():
with gr.Row():
server_addr_text = gr.Textbox(
label="服务ip",
show_label=True,
max_lines=1,
placeholder="localhost",
)
server_http_port_text = gr.Textbox(
label="推理服务端口",
show_label=True,
max_lines=1,
placeholder="8000",
)
server_metric_port_text = gr.Textbox(
label="性能服务端口",
show_label=True,
max_lines=1,
placeholder="8002",
)
with gr.Row():
model_name_text = gr.Textbox(
label="模型名称",
show_label=True,
max_lines=1,
placeholder="yolov5",
)
model_version_text = gr.Textbox(
label="模型版本",
show_label=True,
max_lines=1,
placeholder="1",
)
with gr.Box():
with gr.Tab("组件形式"):
check_button = gr.Button("获取模型输入输出")
component_format_column = gr.Column(visible=False)
with component_format_column:
task_radio = gr.Radio(
choices=list(supported_tasks.keys()),
value='unspecified',
label='任务类型',
visible=True)
gr.Markdown("根据模型需要,挑选文本框或者图像框进行输入")
with gr.Row():
with gr.Column():
gr.Markdown("模型输入")
input_accordions = []
input_name_texts = []
input_images = []
input_texts = []
for i in range(6):
accordion = gr.Accordion(
"输入变量 {}".format(i),
open=True,
visible=False)
with accordion:
input_name_text = gr.Textbox(
label="变量名", interactive=False)
input_image = gr.Image(type='numpy')
input_text = gr.Textbox(
label="文本框", max_lines=1000)
input_accordions.append(accordion)
input_name_texts.append(input_name_text)
input_images.append(input_image)
input_texts.append(input_text)
with gr.Column():
gr.Markdown("模型输出")
output_accordions = []
output_name_texts = []
output_images = []
output_texts = []
for i in range(6):
accordion = gr.Accordion(
"输出变量 {}".format(i),
open=True,
visible=False)
with accordion:
output_name_text = gr.Textbox(
label="变量名", interactive=False)
output_text = gr.Textbox(
label="服务返回的原数据",
interactive=False,
show_label=True)
output_image = gr.Image(
interactive=False)
output_accordions.append(accordion)
output_name_texts.append(output_name_text)
output_images.append(output_image)
output_texts.append(output_text)
component_submit_button = gr.Button("提交请求")
with gr.Tab("原始形式"):
gr.Markdown("模型输入")
raw_payload_text = gr.Textbox(
label="负载数据", max_lines=10000)
with gr.Column():
gr.Markdown("输出")
output_raw_text = gr.Textbox(
label="服务返回的原始数据", interactive=False)
raw_submit_button = gr.Button("提交请求")
with gr.Box():
with gr.Column():
gr.Markdown("服务性能统计(每次提交请求会自动更新数据,您也可以手动点击更新)")
output_html_table = gr.HTML(
label="metrics",
interactive=False,
show_label=False,
value=metrics_table_head.format('', ''))
update_metric_button = gr.Button("更新统计数据")
status_text = gr.Textbox(
label="status",
show_label=True,
max_lines=1,
interactive=False)
all_input_output_components = input_accordions + input_name_texts + input_images + \
input_texts + output_accordions + output_name_texts + output_images + output_texts
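# Handlers below return dicts keyed by gradio components, so only the listed output components are updated.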
def get_input_output_name(server_ip, server_port, model_name,
model_version):
try:
server_addr = server_ip + ':' + server_port
input_metas, output_metas = _http_manager.get_model_meta(
server_addr, model_name, model_version)
except Exception as e:
return {status_text: str(e)}
results = {
component: None
for component in all_input_output_components
}
results[component_format_column] = gr.update(visible=True)
# results[check_button] = gr.update(visible=False)
for input_accordion in input_accordions:
results[input_accordion] = gr.update(visible=False)
for output_accordion in output_accordions:
results[output_accordion] = gr.update(visible=False)
results[status_text] = 'GetInputOutputName Successful'
for i, input_meta in enumerate(input_metas):
results[input_accordions[i]] = gr.update(visible=True)
results[input_name_texts[i]] = input_meta['name']
for i, output_meta in enumerate(output_metas):
results[output_accordions[i]] = gr.update(visible=True)
results[output_name_texts[i]] = output_meta['name']
return results
def component_inference(*args):
server_ip = args[0]
http_port = args[1]
metric_port = args[2]
model_name = args[3]
model_version = args[4]
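# The remaining args arrive flattened in registration order: input names, images, texts, and finally the task type.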
names = args[5:5 + len(input_name_texts)]
images = args[5 + len(input_name_texts):5 + len(input_name_texts) +
len(input_images)]
texts = args[5 + len(input_name_texts) + len(input_images):5 +
len(input_name_texts) + len(input_images) +
len(input_texts)]
task_type = args[-1]
server_addr = server_ip + ':' + http_port
if server_ip and http_port and model_name and model_version:
inputs = {}
for i, input_name in enumerate(names):
if input_name:
if images[i] is not None:
inputs[input_name] = np.array([images[i]])
if texts[i]:
inputs[input_name] = np.array(
[[texts[i].encode('utf-8')]], dtype=np.object_)
try:
infer_results = _http_manager.infer(
server_addr, model_name, model_version, inputs)
results = {status_text: 'Inference Successful'}
for i, (output_name,
data) in enumerate(infer_results.items()):
results[output_name_texts[i]] = output_name
results[output_texts[i]] = str(data)
if task_type != 'unspecified':
try:
results[output_images[i]] = supported_tasks[
task_type](images[0], data)
except Exception:
results[output_images[i]] = None
if metric_port:
html_table = get_metric_data(server_ip, metric_port)
results[output_html_table] = html_table
return results
except Exception as e:
return {status_text: 'Error: {}'.format(e)}
else:
return {
status_text:
'Please input server addr, model name and model version.'
}
def raw_inference(*args):
server_ip = args[0]
http_port = args[1]
metric_port = args[2]
model_name = args[3]
model_version = args[4]
payload_text = args[5]
server_addr = server_ip + ':' + http_port
try:
result = _http_manager.raw_infer(server_addr, model_name,
model_version, payload_text)
results = {
status_text: 'Get response from server',
output_raw_text: result
}
if server_ip and metric_port:
html_table = get_metric_data(server_ip, metric_port)
results[output_html_table] = html_table
return results
except Exception as e:
return {status_text: 'Error: {}'.format(e)}
def update_metric(server_ip, metrics_port):
if server_ip and metrics_port:
try:
html_table = get_metric_data(server_ip, metrics_port)
return {
output_html_table: html_table,
status_text: "Successfully update metrics."
}
except Exception as e:
return {status_text: 'Error: {}'.format(e)}
else:
return {
status_text: 'Please input server ip and metrics_port.'
}
check_button.click(
fn=get_input_output_name,
inputs=[
server_addr_text, server_http_port_text, model_name_text,
model_version_text
],
outputs=[
*all_input_output_components, check_button,
component_format_column, status_text
])
component_submit_button.click(
fn=component_inference,
inputs=[
server_addr_text, server_http_port_text,
server_metric_port_text, model_name_text, model_version_text,
*input_name_texts, *input_images, *input_texts, task_radio
],
outputs=[
*output_name_texts, *output_images, *output_texts, status_text,
output_html_table
])
raw_submit_button.click(
fn=raw_inference,
inputs=[
server_addr_text, server_http_port_text,
server_metric_port_text, model_name_text, model_version_text,
raw_payload_text
],
outputs=[output_raw_text, status_text, output_html_table])
update_metric_button.click(
fn=update_metric,
inputs=[server_addr_text, server_metric_port_text],
outputs=[output_html_table, status_text])
return block
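For reference, a minimal sketch of how this factory could be driven standalone (the port number here is arbitrary; within VisualDL the app is launched in a separate process on a free port, see create_fastdeploy_client further below):

    from visualdl.component.inference.fastdeploy_client.client_app import create_gradio_client_app

    app = create_gradio_client_app()
    app.launch(server_port=8086)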
# Copyright (c) 2022 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import json
import re
import numpy as np
import requests
import tritonclient.http as httpclient
from attrdict import AttrDict
from tritonclient.utils import InferenceServerException
def convert_http_metadata_config(metadata):
metadata = AttrDict(metadata)
return metadata
def prepare_request(inputs_meta, inputs_data, outputs_meta):
'''
inputs_meta: input meta information from the model, mapping name -> info.
inputs_data: user input data, mapping name -> data.
'''
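# Build one InferInput per input declared by the model, converting gradio's uint8 HWC images to the dtype and layout the metadata requests.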
# Set the input data
inputs = []
for input_dict in inputs_meta:
input_name = input_dict['name']
if input_name not in inputs_data:
raise RuntimeError(
'Error: input name {} required by the model does not exist.'.format(
input_name))
if input_dict['datatype'] == 'FP32':
inputs_data[input_name] = inputs_data[input_name].astype(
np.float32
) / 255 # image data returned by gradio is uint8, convert to fp32
if len(input_dict['shape']
) == 3 and input_dict['shape'][0] == 3: # CHW, channel-first without a batch dim
inputs_data[input_name] = inputs_data[input_name][0].transpose(
2, 0, 1)
elif len(input_dict['shape']
) == 4 and input_dict['shape'][1] == 3: # NCHW
inputs_data[input_name] = inputs_data[input_name].transpose(
0, 3, 1, 2)
infer_input = httpclient.InferInput(
input_name, inputs_data[input_name].shape, input_dict['datatype'])
infer_input.set_data_from_numpy(inputs_data[input_name])
inputs.append(infer_input)
outputs = []
for output_dict in outputs_meta:
infer_output = httpclient.InferRequestedOutput(output_dict.name)
outputs.append(infer_output)
return inputs, outputs
metrics_table_head = """
<style>
table, th {{
border:0.1px solid black;
}}
</style>
<div>
<table style="width:100%">
<tr>
<th rowspan="2">模型名称</th>
<th colspan="4">执行统计</th>
<th colspan="5">延迟统计</th>
</tr>
<tr>
<th>请求处理成功数</th>
<th>请求处理失败数</th>
<th>推理batch数</th>
<th>推理样本数</th>
<th>请求处理时间(ms)</th>
<th>任务队列等待时间(ms)</th>
<th>输入处理时间(ms)</th>
<th>模型推理时间(ms)</th>
<th>输出处理时间(ms)</th>
</tr>
{}
</table>
</div>
<br>
<br>
<br>
<br>
<br>
<div>
<table style="width:100%">
<tr>
<th rowspan="2">GPU</th>
<th colspan="4">性能指标</th>
<th colspan="2">显存</th>
</tr>
<tr>
<th>利用率(%)</th>
<th>功率(W)</th>
<th>功率限制(W)</th>
<th>耗电量(J)</th>
<th>总量(GB)</th>
<th>已使用(GB)</th>
</tr>
{}
</table>
</div>
"""
def get_metric_data(server_addr, metric_port): # noqa:C901
'''
Get metrics data from fastdeploy server, and transform it into html table.
Args:
server_addr(str): fastdeployserver ip address
metric_port(int): fastdeployserver metrics port
Returns:
htmltable(str): html table to show metrics data
'''
model_table = {}
gpu_table = {}
metric_column_name = {
"Model": {
"nv_inference_request_success", "nv_inference_request_failure",
"nv_inference_count", "nv_inference_exec_count",
"nv_inference_request_duration_us",
"nv_inference_queue_duration_us",
"nv_inference_compute_input_duration_us",
"nv_inference_compute_infer_duration_us",
"nv_inference_compute_output_duration_us"
},
"GPU": {
"nv_gpu_power_usage", "nv_gpu_power_limit",
"nv_energy_consumption", "nv_gpu_utilization",
"nv_gpu_memory_total_bytes", "nv_gpu_memory_used_bytes"
},
"CPU": {
"nv_cpu_utilization", "nv_cpu_memory_total_bytes",
"nv_cpu_memory_used_bytes"
}
}
try:
res = requests.get("http://{}:{}/metrics".format(
server_addr, metric_port))
except Exception:
return metrics_table_head.format('', '')
metric_content = res.text
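# Each non-comment line follows the Prometheus text format, roughly (illustrative values):
#   nv_inference_request_success{model="yolov5",version="1"} 42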
for content in metric_content.split('\n'):
if content.startswith('#'):
continue
else:
res = re.match(r'(\w+){(.*)} (\w+)',
content) # match output by server metrics interface
if not res:
continue
metric_name = res.group(1)
model = res.group(2)
value = res.group(3)
infos = {}
for info in model.split(','):
k, v = info.split('=')
v = v.strip('"')
infos[k] = v
if metric_name in [
"nv_inference_request_duration_us",
"nv_inference_queue_duration_us",
"nv_inference_compute_input_duration_us",
"nv_inference_compute_infer_duration_us",
"nv_inference_compute_output_duration_us"
]:
value = str(float(value) / 1000)
elif metric_name in [
"nv_gpu_memory_total_bytes", "nv_gpu_memory_used_bytes"
]:
value = str(float(value) / 1024 / 1024 / 1024)
for key, metric_names in metric_column_name.items():
if metric_name in metric_names:
if key == 'Model':
model_name = infos['model']
if model_name not in model_table:
model_table[model_name] = {}
model_table[model_name][metric_name] = value
elif key == 'GPU':
gpu_name = infos['gpu_uuid']
if gpu_name not in gpu_table:
gpu_table[gpu_name] = {}
gpu_table[gpu_name][metric_name] = value
elif key == 'CPU':
pass
model_data_list = []
gpu_data_list = []
model_data_metric_names = [
"nv_inference_request_success", "nv_inference_request_failure",
"nv_inference_exec_count", "nv_inference_count",
"nv_inference_request_duration_us", "nv_inference_queue_duration_us",
"nv_inference_compute_input_duration_us",
"nv_inference_compute_infer_duration_us",
"nv_inference_compute_output_duration_us"
]
gpu_data_metric_names = [
"nv_gpu_utilization", "nv_gpu_power_usage", "nv_gpu_power_limit",
"nv_energy_consumption", "nv_gpu_memory_total_bytes",
"nv_gpu_memory_used_bytes"
]
for k, v in model_table.items():
data = []
data.append(k)
for data_metric in model_data_metric_names:
data.append(v[data_metric])
model_data_list.append(data)
for k, v in gpu_table.items():
data = []
data.append(k)
for data_metric in gpu_data_metric_names:
data.append(v[data_metric])
gpu_data_list.append(data)
model_data = '\n'.join([
"<tr>" + '\n'.join(["<td>" + item + "</td>"
for item in data]) + "</tr>"
for data in model_data_list
])
gpu_data = '\n'.join([
"<tr>" + '\n'.join(["<td>" + item + "</td>"
for item in data]) + "</tr>"
for data in gpu_data_list
])
return metrics_table_head.format(model_data, gpu_data)
class HttpClientManager:
def __init__(self):
self.clients = {} # server url: httpclient
def _create_client(self, server_url):
if server_url in self.clients:
return self.clients[server_url]
try:
fastdeploy_client = httpclient.InferenceServerClient(server_url)
self.clients[server_url] = fastdeploy_client
return fastdeploy_client
except Exception:
raise RuntimeError(
'Cannot connect to server {}, please check your '
'server address'.format(server_url))
def infer(self, server_url, model_name, model_version, inputs):
fastdeploy_client = self._create_client(server_url)
input_metadata, output_metadata = self.get_model_meta(
server_url, model_name, model_version)
inputs, outputs = prepare_request(input_metadata, inputs,
output_metadata)
response = fastdeploy_client.infer(
model_name, inputs, model_version=model_version, outputs=outputs)
results = {}
for output in output_metadata:
result = response.as_numpy(output.name) # datatype: numpy
if output.datatype == 'BYTES': # datatype: bytes
try:
value = result
if len(result.shape) == 1:
value = result[0]
elif len(result.shape) == 2:
value = result[0][0]
elif len(result.shape) == 3:
value = result[0][0][0]
result = json.loads(value) # datatype: json
except Exception:
pass
else:
result = result[0]
results[output.name] = result
return results
def raw_infer(self, server_url, model_name, model_version, raw_input):
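# Forward the user-supplied JSON payload unchanged to Triton's KServe-v2-style HTTP inference endpoint.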
url = 'http://{}/v2/models/{}/versions/{}/infer'.format(
server_url, model_name, model_version)
res = requests.post(url, data=json.dumps(json.loads(raw_input)))
return json.dumps(res.json())
def get_model_meta(self, server_url, model_name, model_version):
fastdeploy_client = self._create_client(server_url)
try:
model_metadata = fastdeploy_client.get_model_metadata(
model_name=model_name, model_version=model_version)
except InferenceServerException as e:
raise RuntimeError("Failed to retrieve the metadata: " + str(e))
model_metadata = convert_http_metadata_config(model_metadata)
input_metadata = model_metadata.inputs
output_metadata = model_metadata.outputs
return input_metadata, output_metadata
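A minimal sketch of how HttpClientManager is driven by the client app above (the address, model name and version are placeholders matching the UI defaults; the input tensor is hypothetical):

    manager = HttpClientManager()
    inputs_meta, outputs_meta = manager.get_model_meta('localhost:8000', 'yolov5', '1')
    inputs = {inputs_meta[0]['name']: some_numpy_array}  # hypothetical numpy input
    results = manager.infer('localhost:8000', 'yolov5', '1', inputs)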
# Copyright (c) 2022 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import numpy as np
__all__ = [
'visualize_detection', 'visualize_keypoint_detection',
'visualize_face_detection', 'visualize_face_alignment',
'visualize_segmentation', 'visualize_matting', 'visualize_ocr',
'visualize_headpose'
]
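# Each helper below rebuilds the corresponding fastdeploy result struct from the server's JSON output and delegates drawing to fd.vision.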
def visualize_detection(image, data):
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for visualizing results, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
boxes = np.array(data['boxes'])
scores = np.array(data['scores'])
label_ids = np.array(data['label_ids'])
masks = np.array(data['masks'])
contain_masks = data['contain_masks']
detection_result = fd.C.vision.DetectionResult()
detection_result.boxes = boxes
detection_result.scores = scores
detection_result.label_ids = label_ids
detection_result.masks = masks
detection_result.contain_masks = contain_masks
result = fd.vision.vis_detection(image, detection_result)
return result
def visualize_keypoint_detection(image, data):
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for visualizing results, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
keypoints = np.array(data['keypoints'])
scores = np.array(data['scores'])
num_joints = np.array(data['num_joints'])
detection_result = fd.C.vision.KeyPointDetectionResult()
detection_result.keypoints = keypoints
detection_result.scores = scores
detection_result.num_joints = num_joints
result = fd.vision.vis_keypoint_detection(image, detection_result)
return result
def visualize_face_detection(image, data):
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for visualizing results, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
boxes = np.array(data['data'])  # read every field before overwriting `data`
scores = np.array(data['scores'])
landmarks = np.array(data['landmarks'])
landmarks_per_face = data['landmarks_per_face']
detection_result = fd.C.vision.FaceDetectionResult()
detection_result.data = boxes
detection_result.scores = scores
detection_result.landmarks = landmarks
detection_result.landmarks_per_face = landmarks_per_face
result = fd.vision.vis_face_detection(image, detection_result)
return result
def visualize_face_alignment(image, data):
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for visualizing results, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
landmarks = np.array(data['landmarks'])
facealignment_result = fd.C.vision.FaceAlignmentResult()
facealignment_result.landmarks = landmarks
result = fd.vision.vis_face_alignment(image, facealignment_result)
return result
def visualize_segmentation(image, data):
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for visualizing results, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
label_ids = np.array(data['label_ids'])
score_map = np.array(data['score_map'])
shape = np.array(data['shape'])
segmentation_result = fd.C.vision.SegmentationResult()
segmentation_result.shape = shape
segmentation_result.score_map = score_map
segmentation_result.label_ids = label_ids
result = fd.vision.vis_segmentation(image, segmentation_result)
return result
def visualize_matting(image, data):
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for visualizing results, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
alpha = np.array(data['alpha'])
foreground = np.array(data['foreground'])
contain_foreground = data['contain_foreground']
shape = np.array(data['shape'])
matting_result = fd.C.vision.MattingResult()
matting_result.alpha = alpha
matting_result.foreground = foreground
matting_result.contain_foreground = contain_foreground
matting_result.shape = shape
result = fd.vision.vis_matting(image, matting_result)
return result
def visualize_ocr(image, data):
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for visualizing results, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
boxes = np.array(data['boxes'])
text = np.array(data['text'])
rec_scores = np.array(data['rec_scores'])
cls_scores = np.array(data['cls_scores'])
cls_labels = data['cls_labels']
ocr_result = fd.C.vision.OCRResult()
ocr_result.boxes = boxes
ocr_result.text = text
ocr_result.rec_scores = rec_scores
ocr_result.cls_scores = cls_scores
ocr_result.cls_labels = cls_labels
result = fd.vision.vis_ppocr(image, ocr_result)
return result
def visualize_headpose(image, data):
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for visualizing results, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
euler_angles = np.array(data['euler_angles'])
headpose_result = fd.C.vision.HeadPoseResult()
headpose_result.euler_angles = euler_angles
result = fd.vision.vis_headpose(image, headpose_result)
return result
This diff is collapsed.
# Copyright (c) 2022 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import datetime
import json
import os
import re
import shutil
import socket
import time
from multiprocessing import Process
from pathlib import Path
import requests
from .fastdeploy_client.client_app import create_gradio_client_app
from .fastdeploy_lib import analyse_config
from .fastdeploy_lib import check_process_zombie
from .fastdeploy_lib import copy_config_file_to_default_config
from .fastdeploy_lib import delete_files_for_process
from .fastdeploy_lib import exchange_format_to_original_format
from .fastdeploy_lib import generate_metric_table
from .fastdeploy_lib import get_alive_fastdeploy_servers
from .fastdeploy_lib import get_config_filenames_for_one_model
from .fastdeploy_lib import get_config_for_one_model
from .fastdeploy_lib import get_process_model_configuration
from .fastdeploy_lib import get_process_output
from .fastdeploy_lib import get_start_arguments
from .fastdeploy_lib import json2pbtxt
from .fastdeploy_lib import kill_process
from .fastdeploy_lib import launch_process
from .fastdeploy_lib import mark_pid_for_dead_process
from .fastdeploy_lib import original_format_to_exchange_format
from .fastdeploy_lib import validate_data
from visualdl.server.api import gen_result
from visualdl.server.api import result
from visualdl.utils.dir import FASTDEPLOYSERVER_PATH
class FastDeployServerApi(object):
def __init__(self):
self.root_dir = Path(os.getcwd())
self.opened_servers = {
}  # Used to store opened server processes, keyed by server name
self.client_port = None
@result()
def get_directory(self, cur_dir):
if self.root_dir not in Path(os.path.abspath(cur_dir)).parents:
cur_dir = '.'
cur_dir, sub_dirs, filenames = next(os.walk(cur_dir))
if Path(self.root_dir) != Path(os.path.abspath(cur_dir)):
sub_dirs.append('..')
sub_dirs = sorted(sub_dirs)
directories = {
'parent_dir':
os.path.relpath(Path(os.path.abspath(cur_dir)), self.root_dir),
'sub_dir':
sub_dirs
}
return directories
@result()
def get_config(self, cur_dir):
all_model_configs, all_model_versions = analyse_config(cur_dir)
return original_format_to_exchange_format(all_model_configs,
all_model_versions)
@result()
def config_update(self, cur_dir, model_name, config, config_filename):
config = json.loads(config)
all_models = exchange_format_to_original_format(config)
model_dir = os.path.join(os.path.abspath(cur_dir), model_name)
filtered_config = validate_data(all_models[model_name])
text_proto = json2pbtxt(json.dumps(filtered_config))
# Back up the user's config data first, so that if the data is corrupted by the front end we can still recover it.
# Backup config filename: {original_name}_vdlbackup_{datetime}.pbtxt
# A backup config can only be used to restore config.pbtxt.
if 'vdlbackup' in config_filename:
raise RuntimeError(
"A backup config file is not permitted to be updated.")
basename = os.path.splitext(config_filename)[0]
shutil.copy(
os.path.join(model_dir, config_filename),
os.path.join(
model_dir, '{}_vdlbackup_{}.pbtxt'.format(
basename,
datetime.datetime.now().isoformat())))
with open(os.path.join(model_dir, config_filename), 'w') as f:
f.write(text_proto)
return
@result()
def start_server(self, configs):
configs = json.loads(configs)
process = launch_process(configs)
if process.poll() is not None:
raise RuntimeError(
"Failed to launch fastdeployserver, please check that fastdeployserver "
"is installed in the environment.")
server_name = configs['server-name'] if configs[
'server-name'] else str(process.pid)
self.opened_servers[server_name] = process
return server_name
@result()
def stop_server(self, server_id):
if server_id in self.opened_servers: # check if server_id in self.opened_servers
kill_process(self.opened_servers[server_id])
del self.opened_servers[server_id]
elif server_id in set(
os.listdir(FASTDEPLOYSERVER_PATH)): # check if server_id is in
# FASTDEPLOYSERVER_PATH (it may have been launched by another vdl app instance under gunicorn)
kill_process(server_id)
delete_files_for_process(server_id)
self._poll_zombie_process()
@result('text/plain')
def get_server_output(self, server_id, length):
length = int(length)
if server_id in self.opened_servers: # check if server_id in self.opened_servers
return get_process_output(server_id, length)
elif str(server_id) in set(
os.listdir(FASTDEPLOYSERVER_PATH)): # check if server_id is in
# FASTDEPLOYSERVER_PATH (it may have been launched by another vdl app instance under gunicorn)
return get_process_output(server_id, length)
else:
return
@result()
def get_server_metric(self, server_id):
args = get_start_arguments(server_id)
host = 'localhost'
port = args.get('metrics-port', 8002)
return generate_metric_table(host, port)
@result()
def get_server_list(self):
return get_alive_fastdeploy_servers()
@result()
def check_server_alive(self, server_id):
self._poll_zombie_process()
if check_process_zombie(server_id) is True:
raise RuntimeError(
"Server {} is down due to an exception or was killed, please check "
"the reason in the log, then close this server.".format(server_id))
return
@result()
def get_server_config(self, server_id):
return get_process_model_configuration(server_id)
@result()
def get_pretrain_model_list(self):
'''
Get all available fastdeploy models from hub server.
'''
res = requests.get(
'http://paddlepaddle.org.cn/paddlehub/fastdeploy_listmodels')
result = res.json()
if result['status'] != 0:
raise RuntimeError(
"Failed to get pre-trained model list from hub server.")
else:
data = result['data']
model_list = {}
for category, models in data.items():
if category not in model_list:
model_list[category] = set()
for model in models:
model_list[category].add(model['name'])
# adapt data format for frontend
models_info = []
for category, model_names in model_list.items():
models_info.append({
"value": category,
"label": category,
"children": []
})
for model_name in sorted(model_names):
models_info[-1]["children"].append({
"value": model_name,
"label": model_name
})
return models_info
@result()
def download_pretrain_model(self, cur_dir, model_name, version,
pretrain_model_name):
version_resource_dir = os.path.join(
os.path.abspath(cur_dir), model_name, version)
try:
import fastdeploy as fd
except Exception:
raise RuntimeError(
"fastdeploy is required for downloading pretrained models, please refer to "
"https://github.com/PaddlePaddle/FastDeploy to install fastdeploy")
model_path = fd.download_model(
name=pretrain_model_name, path=version_resource_dir)
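# Normalize the downloaded files to the names fastdeployserver expects (model.onnx, or model.pdmodel / model.pdiparams) and flatten them into the version directory.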
if model_path:
if '.onnx' in model_path:
shutil.move(
model_path,
os.path.join(os.path.dirname(model_path), 'model.onnx'))
else:
for filename in os.listdir(model_path):
if '.pdmodel' in filename or '.pdiparams' in filename:
shutil.move(
os.path.join(model_path, filename),
os.path.join(
os.path.dirname(model_path), 'model{}'.format(
os.path.splitext(filename)[1])))
else:
shutil.move(
os.path.join(model_path, filename),
os.path.join(
os.path.dirname(model_path), filename))
shutil.rmtree(model_path)
version_info_for_frontend = []
for version_name in os.listdir(os.path.join(cur_dir, model_name)):
if re.match(
r'\d+',
version_name): # version directory consists of numbers
version_filenames_dict_for_frontend = {}
version_filenames_dict_for_frontend['title'] = version_name
version_filenames_dict_for_frontend['key'] = version_name
version_filenames_dict_for_frontend['children'] = []
for filename in os.listdir(
os.path.join(cur_dir, model_name, version_name)):
version_filenames_dict_for_frontend['children'].append(
{
'title': filename,
'key': filename
})
version_info_for_frontend.append(
version_filenames_dict_for_frontend)
return version_info_for_frontend
else:
raise RuntimeError(
"Failed to download pre-trained model {}.".format(
pretrain_model_name))
@result()
def get_config_for_model(self, cur_dir, name, config_filename):
return get_config_for_one_model(cur_dir, name, config_filename)
@result()
def get_config_filenames_for_model(self, cur_dir, name):
return get_config_filenames_for_one_model(cur_dir, name)
@result()
def delete_config_for_model(self, cur_dir, name, config_filename):
if self.root_dir not in Path(
os.path.abspath(cur_dir)
).parents: # prevent users from removing files outside the model repository
raise RuntimeError(
'Failed to delete config file, please check filepath.')
if os.path.exists(os.path.join(cur_dir, name, config_filename)):
os.remove(os.path.join(cur_dir, name, config_filename))
return get_config_filenames_for_one_model(cur_dir, name)
@result()
def set_default_config_for_model(self, cur_dir, name, config_filename):
model_dir = os.path.join(os.path.abspath(cur_dir), name)
# backup config.pbtxt to config_vdlbackup_{datetime}.pbtxt
if os.path.exists(os.path.join(model_dir, 'config.pbtxt')):
shutil.copy(
os.path.join(model_dir, 'config.pbtxt'),
os.path.join(
model_dir, 'config_vdlbackup_{}.pbtxt'.format(
datetime.datetime.now().isoformat())))
if config_filename != 'config.pbtxt':
copy_config_file_to_default_config(model_dir, config_filename)
return
@result()
def delete_resource_for_model(self, cur_dir, model_name, version,
resource_filename):
if self.root_dir not in Path(
os.path.abspath(cur_dir)
).parents: # prevent users from removing files outside the model repository
raise RuntimeError(
'Failed to delete resource file, please check filepath.')
resource_path = os.path.join(
os.path.abspath(cur_dir), model_name, version, resource_filename)
if os.path.exists(resource_path):
os.remove(resource_path)
version_info_for_frontend = []
for version_name in os.listdir(os.path.join(cur_dir, model_name)):
if re.match(r'\d+',
version_name): # version directory consists of numbers
version_filenames_dict_for_frontend = {}
version_filenames_dict_for_frontend['title'] = version_name
version_filenames_dict_for_frontend['key'] = version_name
version_filenames_dict_for_frontend['children'] = []
for filename in os.listdir(
os.path.join(cur_dir, model_name, version_name)):
version_filenames_dict_for_frontend['children'].append({
'title':
filename,
'key':
filename
})
version_info_for_frontend.append(
version_filenames_dict_for_frontend)
return version_info_for_frontend
@result()
def rename_resource_for_model(self, cur_dir, model_name, version,
resource_filename, new_filename):
if self.root_dir not in Path(
os.path.abspath(cur_dir)
).parents: # prevent users from modifying files outside the model repository
raise RuntimeError(
'Failed to rename resource file, please check filepath.')
resource_path = os.path.join(
os.path.abspath(cur_dir), model_name, version, resource_filename)
new_file_path = os.path.join(
os.path.abspath(cur_dir), model_name, version, new_filename)
if os.path.exists(resource_path):
shutil.move(resource_path, new_file_path)
version_info_for_frontend = []
for version_name in os.listdir(os.path.join(cur_dir, model_name)):
if re.match(r'\d+',
version_name): # version directory consists of numbers
version_filenames_dict_for_frontend = {}
version_filenames_dict_for_frontend['title'] = version_name
version_filenames_dict_for_frontend['key'] = version_name
version_filenames_dict_for_frontend['children'] = []
for filename in os.listdir(
os.path.join(cur_dir, model_name, version_name)):
version_filenames_dict_for_frontend['children'].append({
'title':
filename,
'key':
filename
})
version_info_for_frontend.append(
version_filenames_dict_for_frontend)
return version_info_for_frontend
def create_fastdeploy_client(self):
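# Lazily launch the gradio client app once, in a separate process bound to a free local port; later calls reuse that port.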
if self.client_port is None:
def get_free_tcp_port():
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
tcp.bind(('localhost', 0))
addr, port = tcp.getsockname()
tcp.close()
return port
self.client_port = get_free_tcp_port()
app = create_gradio_client_app()
process = Process(
target=app.launch, kwargs={'server_port': self.client_port})
process.start()
def check_alive():
while True:
try:
requests.get('http://localhost:{}/'.format(
self.client_port))
break
except Exception:
time.sleep(1)
check_alive()
return self.client_port
def _poll_zombie_process(self):
# Check for servers killed by another vdl app instance that have become zombies.
should_delete = []
for server_id, process in self.opened_servers.items():
if process.poll() is not None:
mark_pid_for_dead_process(server_id)
should_delete.append(server_id)
for server_id in should_delete:
del self.opened_servers[server_id]
def create_fastdeploy_api_call():
api = FastDeployServerApi()
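# Dispatch table: api method name -> (handler, names of the request arguments it expects).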
routes = {
'get_directory': (api.get_directory, ['dir']),
'config_update': (api.config_update,
['dir', 'name', 'config', 'config_filename']),
'get_config': (api.get_config, ['dir']),
'get_config_filenames_for_model': (api.get_config_filenames_for_model,
['dir', 'name']),
'get_config_for_model': (api.get_config_for_model,
['dir', 'name', 'config_filename']),
'set_default_config_for_model': (api.set_default_config_for_model,
['dir', 'name', 'config_filename']),
'delete_config_for_model': (api.delete_config_for_model,
['dir', 'name', 'config_filename']),
'start_server': (api.start_server, ['config']),
'stop_server': (api.stop_server, ['server_id']),
'get_server_output': (api.get_server_output, ['server_id', 'length']),
'create_fastdeploy_client': (api.create_fastdeploy_client, []),
'get_server_list': (api.get_server_list, []),
'get_server_metric': (api.get_server_metric, ['server_id']),
'get_server_config': (api.get_server_config, ['server_id']),
'get_pretrain_model_list': (api.get_pretrain_model_list, []),
'check_server_alive': (api.check_server_alive, ['server_id']),
'download_pretrain_model':
(api.download_pretrain_model,
['dir', 'name', 'version', 'pretrain_model_name']),
'delete_resource_for_model':
(api.delete_resource_for_model,
['dir', 'name', 'version', 'resource_filename']),
'rename_resource_for_model': (api.rename_resource_for_model, [
'dir', 'name', 'version', 'resource_filename', 'new_filename'
])
}
def call(path: str, args):
route = routes.get(path)
if not route:
return json.dumps(gen_result(
status=1, msg='api not found')), 'application/json', None
method, call_arg_names = route
call_args = [args.get(name) for name in call_arg_names]
return method(*call_args)
return call
# Copyright (c) 2022 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
This diff is collapsed.
This diff is collapsed.
@@ -417,7 +417,10 @@ def get_component_tabs(*apis, vdl_args, request_args):
all_tabs.update(api('component_tabs', request_args))
all_tabs.add('static_graph')
else:
return ['static_graph', 'x2paddle', 'fastdeploy_server']
return [
'static_graph', 'x2paddle', 'fastdeploy_server',
'fastdeploy_client'
]
return list(all_tabs)
@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import json
import multiprocessing
import os
import re
import sys
import threading
import time
import urllib
import webbrowser
import requests
@@ -32,6 +34,8 @@ from flask_babel import Babel
import visualdl.server
from visualdl import __version__
from visualdl.component.inference.fastdeploy_lib import get_start_arguments
from visualdl.component.inference.fastdeploy_server import create_fastdeploy_api_call
from visualdl.component.inference.model_convert_server import create_model_convert_api_call
from visualdl.component.profiler.profiler_server import create_profiler_api_call
from visualdl.server.api import create_api_call
@@ -71,6 +75,7 @@ def create_app(args): # noqa: C901
api_call = create_api_call(args.logdir, args.model, args.cache_timeout)
profiler_api_call = create_profiler_api_call(args.logdir)
inference_api_call = create_model_convert_api_call()
fastdeploy_api_call = create_fastdeploy_api_call()
if args.telemetry:
update_util.PbUpdater(args.product).start()
@@ -153,6 +158,141 @@ def create_app(args): # noqa: C901
return make_response(
Response(data, mimetype=mimetype, headers=headers))
@app.route(api_path + '/fastdeploy/<path:method>', methods=["GET", "POST"])
def serve_fastdeploy_api(method):
if request.method == 'POST':
data, mimetype, headers = fastdeploy_api_call(method, request.form)
else:
data, mimetype, headers = fastdeploy_api_call(method, request.args)
return make_response(
Response(data, mimetype=mimetype, headers=headers))
@app.route(
api_path + '/fastdeploy/fastdeploy_client', methods=["GET", "POST"])
def serve_fastdeploy_create_fastdeploy_client():
try:
if request.method == 'POST':
fastdeploy_api_call('create_fastdeploy_client', request.form)
request_args = request.form
else:
fastdeploy_api_call('create_fastdeploy_client', request.args)
request_args = request.args
except Exception as e:
error_msg = '{}'.format(e)
return make_response(error_msg)
args = urllib.parse.urlencode(request_args)
if args:
return redirect(
api_path + "/fastdeploy/fastdeploy_client/app?{}".format(args),
code=302)
return redirect(
api_path + "/fastdeploy/fastdeploy_client/app", code=302)
@app.route(
api_path + "/fastdeploy/fastdeploy_client/<path:path>",
methods=["GET", "POST"])
def request_fastdeploy_create_fastdeploy_client_app(path: str):
'''
Gradio app server url interface. We route urls for the gradio app to the gradio server.
Args:
path(str): The resource path requested from the gradio server.
Returns:
Whatever the gradio server returns for that path.
'''
if request.method == 'POST':
port = fastdeploy_api_call('create_fastdeploy_client',
request.form)
request_args = request.form
else:
port = fastdeploy_api_call('create_fastdeploy_client',
request.args)
request_args = request.args
if path == 'app':
proxy_url = request.url.replace(
request.host_url.rstrip('/') + api_path +
'/fastdeploy/fastdeploy_client/app',
'http://localhost:{}/'.format(port))
else:
proxy_url = request.url.replace(
request.host_url.rstrip('/') + api_path +
'/fastdeploy/fastdeploy_client/',
'http://localhost:{}/'.format(port))
resp = requests.request(
method=request.method,
url=proxy_url,
headers={
key: value
for (key, value) in request.headers if key != 'Host'
},
data=request.get_data(),
cookies=request.cookies,
allow_redirects=False)
if path == 'app':
content = resp.content
if request_args and 'server_id' in request_args:
server_id = request_args.get('server_id')
start_args = get_start_arguments(server_id)
http_port = start_args.get('http-port', '')
metrics_port = start_args.get('metrics-port', '')
model_name = start_args.get('default_model_name', '')
content = content.decode()
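# Patch the serialized gradio config so the client form comes pre-filled with the selected server's address, ports, model name and version.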
try:
default_server_addr = re.search(
'"label": {}.*?"value": "".*?}}'.format(
json.dumps("服务ip", ensure_ascii=True).replace(
'\\', '\\\\')), content).group(0)
cur_server_addr = default_server_addr.replace(
'"value": ""', '"value": "localhost"')
default_http_port = re.search(
'"label": {}.*?"value": "".*?}}'.format(
json.dumps("推理服务端口", ensure_ascii=True).replace(
'\\', '\\\\')), content).group(0)
cur_http_port = default_http_port.replace(
'"value": ""', '"value": "{}"'.format(http_port))
default_metrics_port = re.search(
'"label": {}.*?"value": "".*?}}'.format(
json.dumps("性能服务端口", ensure_ascii=True).replace(
'\\', '\\\\')), content).group(0)
cur_metrics_port = default_metrics_port.replace(
'"value": ""', '"value": "{}"'.format(metrics_port))
default_model_name = re.search(
'"label": {}.*?"value": "".*?}}'.format(
json.dumps("模型名称", ensure_ascii=True).replace(
'\\', '\\\\')), content).group(0)
cur_model_name = default_model_name.replace(
'"value": ""', '"value": "{}"'.format(model_name))
default_model_version = re.search(
'"label": {}.*?"value": "".*?}}'.format(
json.dumps("模型版本", ensure_ascii=True).replace(
'\\', '\\\\')), content).group(0)
cur_model_version = default_model_version.replace(
'"value": ""', '"value": "{}"'.format('1'))
content = content.replace(default_server_addr,
cur_server_addr)
if http_port:
content = content.replace(default_http_port,
cur_http_port)
if metrics_port:
content = content.replace(default_metrics_port,
cur_metrics_port)
if model_name:
content = content.replace(default_model_name,
cur_model_name)
content = content.replace(default_model_version,
cur_model_version)
except Exception:
pass
finally:
content = content.encode()
else:
content = resp.content
headers = [(name, value) for (name, value) in resp.raw.headers.items()]
response = Response(content, resp.status_code, headers)
return response
@app.route(api_path + '/component_tabs')
def component_tabs():
data, mimetype, headers = get_component_tabs(
@@ -78,7 +78,8 @@ def validate_args(args):
supported_tabs = [
'scalar', 'image', 'text', 'embeddings', 'audio', 'histogram',
'hyper_parameters', 'static_graph', 'dynamic_graph', 'pr_curve',
'roc_curve', 'profiler', 'x2paddle', 'fastdeploy_server'
'roc_curve', 'profiler', 'x2paddle', 'fastdeploy_server',
'fastdeploy_client'
]
if args.component_tabs is not None:
for component_tab in args.component_tabs:
@@ -23,6 +23,7 @@ USER_HOME = os.path.expanduser('~')
VDL_HOME = os.path.join(USER_HOME, '.visualdl')
CONF_HOME = os.path.join(VDL_HOME, 'conf')
CONFIG_PATH = os.path.join(CONF_HOME, 'config.json')
FASTDEPLOYSERVER_PATH = os.path.join(VDL_HOME, 'fastdeployserver')
X2PADDLE_CACHE_PATH = os.path.join(VDL_HOME, 'x2paddle')
@@ -32,5 +33,7 @@ def init_vdl_config():
if not os.path.exists(CONFIG_PATH) or 0 == os.path.getsize(CONFIG_PATH):
with open(CONFIG_PATH, 'w') as fp:
fp.write(json.dumps(default_vdl_config))
if not os.path.exists(FASTDEPLOYSERVER_PATH):
os.makedirs(FASTDEPLOYSERVER_PATH, exist_ok=True)
if not os.path.exists(X2PADDLE_CACHE_PATH):
os.makedirs(X2PADDLE_CACHE_PATH, exist_ok=True)