提交 25968ad1 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

add model arh

上级 adb69517
...@@ -60,7 +60,7 @@ def main_wyw2s(video_path,cfg_file): ...@@ -60,7 +60,7 @@ def main_wyw2s(video_path,cfg_file):
facebank_path = config["facebank_path"], facebank_path = config["facebank_path"],
threshold = float(config["face_verify_threshold"])) threshold = float(config["face_verify_threshold"]))
face_multitask_model = FaceMuitiTask_Model(model_path = config["face_multitask_model_path"]) face_multitask_model = FaceMuitiTask_Model(model_path = config["face_multitask_model_path"], model_arch = config["face_multitask_model_arch"])
face_euler_model = FaceAngle_Model(model_path = config["face_euler_model_path"]) face_euler_model = FaceAngle_Model(model_path = config["face_euler_model_path"])
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import cv2 import cv2
import torch.nn.functional as F import torch.nn.functional as F
from face_multi_task.network.resnet import resnet50 from face_multi_task.network.resnet import resnet50,resnet34,resnet18
from face_multi_task.utils.common_utils import * from face_multi_task.utils.common_utils import *
import numpy as np import numpy as np
...@@ -17,6 +17,7 @@ class FaceMuitiTask_Model(object): ...@@ -17,6 +17,7 @@ class FaceMuitiTask_Model(object):
model_path = './components/face_multi_task/weights_multask/resnet_50_imgsize-256-20210411.pth', model_path = './components/face_multi_task/weights_multask/resnet_50_imgsize-256-20210411.pth',
img_size=256, img_size=256,
num_classes = 196,# 人脸关键点,年龄,性别 num_classes = 196,# 人脸关键点,年龄,性别
model_arch = "resnet50",# 模型结构
): ):
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
...@@ -24,7 +25,13 @@ class FaceMuitiTask_Model(object): ...@@ -24,7 +25,13 @@ class FaceMuitiTask_Model(object):
self.device = torch.device("cuda:0" if use_cuda else "cpu") # 可选的设备类型及序号 self.device = torch.device("cuda:0" if use_cuda else "cpu") # 可选的设备类型及序号
self.img_size = img_size self.img_size = img_size
#----------------------------------------------------------------------- #-----------------------------------------------------------------------
face_multi_model = resnet50(landmarks_num=num_classes, img_size=img_size) if model_arch == "resnet50":
face_multi_model = resnet50(landmarks_num=num_classes, img_size=img_size)
elif model_arch == "resnet34":
face_multi_model = resnet34(landmarks_num=num_classes, img_size=img_size)
elif model_arch == "resnet18":
face_multi_model = resnet18(landmarks_num=num_classes, img_size=img_size)
chkpt = torch.load(model_path, map_location=lambda storage, loc: storage) chkpt = torch.load(model_path, map_location=lambda storage, loc: storage)
face_multi_model.load_state_dict(chkpt) face_multi_model.load_state_dict(chkpt)
face_multi_model.eval() face_multi_model.eval()
......
...@@ -11,6 +11,10 @@ face_verify_backbone_path=./wyw2s_models/face_verify-model_ir_se-50.pth ...@@ -11,6 +11,10 @@ face_verify_backbone_path=./wyw2s_models/face_verify-model_ir_se-50.pth
facebank_path=./wyw2s_models/facebank facebank_path=./wyw2s_models/facebank
face_verify_threshold=1.2 face_verify_threshold=1.2
face_multitask_model_path=./wyw2s_models/face_multitask-resnet_50_imgsize-256-20210411.pth #face_multitask_model_path=./wyw2s_models/face_multitask-resnet_50_imgsize-256-20210411.pth
#face_multitask_model_arch=resnet50
face_multitask_model_path=./wyw2s_models/face_multitask-resnet_34_imgsize-256-20210423.pth
face_multitask_model_arch=resnet34
face_euler_model_path=./wyw2s_models/euler_angle-resnet_18_imgsize_256.pth face_euler_model_path=./wyw2s_models/euler_angle-resnet_18_imgsize_256.pth
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册