未验证 提交 77d5d106 编写于 作者: M Meiyim 提交者: GitHub

update pretrain demo (#491)

* update pretrain demo

* modeul path fix for ernie-gen

* dygraph distributed cls + init_checkpoint option
上级 0806468e
......@@ -53,6 +53,7 @@ if __name__ == '__main__':
parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
parser.add_argument('--save_dir', type=str, default=None, help='model output directory')
parser.add_argument('--wd', type=int, default=0.01, help='weight decay, aka L2 regularizer')
parser.add_argument('--init_checkpoint', type=str, default=None, help='checkpoint to warm start from')
args = parser.parse_args()
......@@ -99,6 +100,12 @@ if __name__ == '__main__':
with FD.guard(place):
ctx = FD.parallel.prepare_context()
model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='')
if args.init_checkpoint is not None:
log.info('loading checkpoint from %s' % args.init_checkpoint)
sd, _ = FD.load_dygraph(args.init_checkpoint)
model.set_dict(sd)
model = FD.parallel.DataParallel(model, ctx)
g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
......
......@@ -28,7 +28,7 @@ example:
make pretrain data with:
```script
python3 ernie/pretrain/make_pretrain_data.py input_file output_file.gz --vocab ./pretrained/vocab.txt
python3 ./demo/pretrain/make_pretrain_data.py input_file output_file.gz --vocab /path/to/ernie1.0/vocab.txt
```
2. run distributed pretrain
......@@ -36,9 +36,9 @@ python3 ernie/pretrain/make_pretrain_data.py input_file output_file.gz --vocab
```sript
python3 -m paddle.distributed.launch \
./ernie/pretrain/pretrain_dygraph.py \
--data_dir data/* \
--from_pretrained ./ernie_1.0_pretrain_dir/
./demo/pretrain/pretrain_dygraph.py \
--data_dir "data/*.gz" \
--from_pretrained /path/to/ernie1.0_pretrain_dir/
```
......@@ -124,7 +124,7 @@ if __name__ == '__main__':
log.setLevel(logging.DEBUG)
from tokenizing_ernie import _wordpiece
from ernie.tokenizing_ernie import _wordpiece
pat = re.compile(r'([a-zA-Z0-9]+|\S)')
vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab, 'rb'))}
......
......@@ -37,7 +37,7 @@ from ernie.modeling_ernie import _build_linear, _build_ln, append_name
from ernie.tokenizing_ernie import ErnieTokenizer
from ernie.optimization import AdamW, LinearDecay
from experimental.seq2seq.decode import beam_search_infilling, post_process
from demo.seq2seq.decode import beam_search_infilling, post_process
from propeller import log
import propeller.paddle as propeller
......@@ -295,7 +295,7 @@ if __name__ == '__main__':
parser.add_argument('--predict_output_dir', type=str, default=None, help='predict file output directory')
parser.add_argument('--attn_token', type=str, default='[ATTN]', help='if [ATTN] not in vocab, you can specified [MAKK] as attn-token')
parser.add_argument('--inference_model_dir', type=str, default=None, help='inference model output directory')
parser.add_argument('--init_checkpoint', type=str, default=None)
parser.add_argument('--init_checkpoint', type=str, default=None, help='checkpoint to warm start from')
parser.add_argument('--save_dir', type=str, default=None, help='model output directory')
parser.add_argument('--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册