未验证 提交 17dbe2b1 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #93 from heavengate/update_tsm

update TSM use new APIS
...@@ -19,8 +19,8 @@ import os ...@@ -19,8 +19,8 @@ import os
import argparse import argparse
import numpy as np import numpy as np
from paddle.incubate.hapi.model import Input, set_device import paddle
from paddle.incubate.hapi.vision.transforms import Compose from paddle.vision.transforms import Compose
from check import check_gpu, check_version from check import check_gpu, check_version
from modeling import tsm_resnet50 from modeling import tsm_resnet50
...@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__) ...@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__)
def main(): def main():
device = set_device(FLAGS.device) device = paddle.set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None paddle.disable_static(device) if FLAGS.dynamic else None
transform = Compose([GroupScale(), GroupCenterCrop(), NormalizeImage()]) transform = Compose([GroupScale(), GroupCenterCrop(), NormalizeImage()])
dataset = KineticsDataset( dataset = KineticsDataset(
...@@ -47,9 +47,7 @@ def main(): ...@@ -47,9 +47,7 @@ def main():
model = tsm_resnet50( model = tsm_resnet50(
num_classes=len(labels), pretrained=FLAGS.weights is None) num_classes=len(labels), pretrained=FLAGS.weights is None)
inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')] model.prepare()
model.prepare(inputs=inputs, device=FLAGS.device)
if FLAGS.weights is not None: if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True) model.load(FLAGS.weights, reset_optimizer=True)
......
...@@ -19,13 +19,11 @@ import os ...@@ -19,13 +19,11 @@ import os
import argparse import argparse
import numpy as np import numpy as np
import paddle
from paddle import fluid from paddle import fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.incubate.hapi.model import Model, Input, set_device from paddle.vision.transforms import Compose
from paddle.incubate.hapi.loss import CrossEntropy
from paddle.incubate.hapi.metrics import Accuracy
from paddle.incubate.hapi.vision.transforms import Compose
from modeling import tsm_resnet50 from modeling import tsm_resnet50
from check import check_gpu, check_version from check import check_gpu, check_version
...@@ -50,8 +48,8 @@ def make_optimizer(step_per_epoch, parameter_list=None): ...@@ -50,8 +48,8 @@ def make_optimizer(step_per_epoch, parameter_list=None):
def main(): def main():
device = set_device(FLAGS.device) device = paddle.set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None paddle.disable_static(device) if FLAGS.dynamic else None
train_transform = Compose([ train_transform = Compose([
GroupScale(), GroupMultiScaleCrop(), GroupRandomCrop(), GroupScale(), GroupMultiScaleCrop(), GroupRandomCrop(),
...@@ -79,16 +77,10 @@ def main(): ...@@ -79,16 +77,10 @@ def main():
/ ParallelEnv().nranks) / ParallelEnv().nranks)
optim = make_optimizer(step_per_epoch, model.parameters()) optim = make_optimizer(step_per_epoch, model.parameters())
inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare( model.prepare(
optim, optimizer=optim,
CrossEntropy(), loss=paddle.nn.CrossEntropyLoss(),
metrics=Accuracy(topk=(1, 5)), metrics=paddle.metric.Accuracy(topk=(1, 5)))
inputs=inputs,
labels=labels,
device=FLAGS.device)
if FLAGS.eval_only: if FLAGS.eval_only:
if FLAGS.weights is not None: if FLAGS.weights is not None:
......
...@@ -13,12 +13,13 @@ ...@@ -13,12 +13,13 @@
#limitations under the License. #limitations under the License.
import math import math
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.incubate.hapi.model import Model from paddle.static import InputSpec
from paddle.incubate.hapi.download import get_weights_path_from_url from paddle.utils.download import get_weights_path_from_url
__all__ = ["TSM_ResNet", "tsm_resnet50"] __all__ = ["TSM_ResNet", "tsm_resnet50"]
...@@ -112,7 +113,7 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -112,7 +113,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
return y return y
class TSM_ResNet(Model): class TSM_ResNet(fluid.dygraph.Layer):
""" """
TSM network with ResNet as backbone TSM network with ResNet as backbone
...@@ -193,7 +194,10 @@ class TSM_ResNet(Model): ...@@ -193,7 +194,10 @@ class TSM_ResNet(Model):
def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True): def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True):
model = TSM_ResNet(num_layers, seg_num, num_classes) inputs = [InputSpec([None, 8, 3, 224, 224], 'float32', name='image')]
labels = [InputSpec([None, 1], 'int64', name='label')]
net = TSM_ResNet(num_layers, seg_num, num_classes)
model = paddle.Model(net, inputs, labels)
if pretrained: if pretrained:
assert num_layers in pretrain_infos.keys(), \ assert num_layers in pretrain_infos.keys(), \
"TSM-ResNet{} do not have pretrained weights now, " \ "TSM-ResNet{} do not have pretrained weights now, " \
...@@ -201,6 +205,8 @@ def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True): ...@@ -201,6 +205,8 @@ def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True):
weight_path = get_weights_path_from_url(*(pretrain_infos[num_layers])) weight_path = get_weights_path_from_url(*(pretrain_infos[num_layers]))
assert weight_path.endswith('.pdparams'), \ assert weight_path.endswith('.pdparams'), \
"suffix of weight must be .pdparams" "suffix of weight must be .pdparams"
# weight_dict, _ = fluid.load_dygraph(weight_path)
# model.set_dict(weight_dict)
model.load(weight_path) model.load(weight_path)
return model return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册