未验证 提交 17dbe2b1 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #93 from heavengate/update_tsm

update TSM use new APIS
......@@ -19,8 +19,8 @@ import os
import argparse
import numpy as np
from paddle.incubate.hapi.model import Input, set_device
from paddle.incubate.hapi.vision.transforms import Compose
import paddle
from paddle.vision.transforms import Compose
from check import check_gpu, check_version
from modeling import tsm_resnet50
......@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__)
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
device = paddle.set_device(FLAGS.device)
paddle.disable_static(device) if FLAGS.dynamic else None
transform = Compose([GroupScale(), GroupCenterCrop(), NormalizeImage()])
dataset = KineticsDataset(
......@@ -47,9 +47,7 @@ def main():
model = tsm_resnet50(
num_classes=len(labels), pretrained=FLAGS.weights is None)
inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
model.prepare(inputs=inputs, device=FLAGS.device)
model.prepare()
if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True)
......
......@@ -19,13 +19,11 @@ import os
import argparse
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.incubate.hapi.model import Model, Input, set_device
from paddle.incubate.hapi.loss import CrossEntropy
from paddle.incubate.hapi.metrics import Accuracy
from paddle.incubate.hapi.vision.transforms import Compose
from paddle.vision.transforms import Compose
from modeling import tsm_resnet50
from check import check_gpu, check_version
......@@ -50,8 +48,8 @@ def make_optimizer(step_per_epoch, parameter_list=None):
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
device = paddle.set_device(FLAGS.device)
paddle.disable_static(device) if FLAGS.dynamic else None
train_transform = Compose([
GroupScale(), GroupMultiScaleCrop(), GroupRandomCrop(),
......@@ -79,16 +77,10 @@ def main():
/ ParallelEnv().nranks)
optim = make_optimizer(step_per_epoch, model.parameters())
inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(
optim,
CrossEntropy(),
metrics=Accuracy(topk=(1, 5)),
inputs=inputs,
labels=labels,
device=FLAGS.device)
optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy(topk=(1, 5)))
if FLAGS.eval_only:
if FLAGS.weights is not None:
......
......@@ -13,12 +13,13 @@
#limitations under the License.
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.incubate.hapi.model import Model
from paddle.incubate.hapi.download import get_weights_path_from_url
from paddle.static import InputSpec
from paddle.utils.download import get_weights_path_from_url
__all__ = ["TSM_ResNet", "tsm_resnet50"]
......@@ -112,7 +113,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
return y
class TSM_ResNet(Model):
class TSM_ResNet(fluid.dygraph.Layer):
"""
TSM network with ResNet as backbone
......@@ -193,7 +194,10 @@ class TSM_ResNet(Model):
def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True):
model = TSM_ResNet(num_layers, seg_num, num_classes)
inputs = [InputSpec([None, 8, 3, 224, 224], 'float32', name='image')]
labels = [InputSpec([None, 1], 'int64', name='label')]
net = TSM_ResNet(num_layers, seg_num, num_classes)
model = paddle.Model(net, inputs, labels)
if pretrained:
assert num_layers in pretrain_infos.keys(), \
"TSM-ResNet{} do not have pretrained weights now, " \
......@@ -201,6 +205,8 @@ def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True):
weight_path = get_weights_path_from_url(*(pretrain_infos[num_layers]))
assert weight_path.endswith('.pdparams'), \
"suffix of weight must be .pdparams"
# weight_dict, _ = fluid.load_dygraph(weight_path)
# model.set_dict(weight_dict)
model.load(weight_path)
return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册