提交 d2fece3e 编写于 作者: W wuzewu

Update comment

上级 1562d65f
......@@ -17,7 +17,7 @@ import os
import pickle
import time
from collections import defaultdict
from typing import Any, Callable
from typing import Any, Callable, List
import paddle
from paddle.distributed import ParallelEnv
......@@ -29,7 +29,25 @@ from paddlehub.utils.utils import Timer
class Trainer(object):
'''
Trainer
Model trainer
Args:
model(paddle.nn.Layer) : Model to train or evaluate.
strategy(paddle.optimizer.Optimizer) : Optimizer strategy.
use_vdl(bool) : Whether to use visualdl to record training data.
checkpoint_dir(str) : Directory where the checkpoint is saved, and the trainer will restore the
state and model parameters from the checkpoint.
compare_metrics(callable) : The method of comparing the model metrics. If not specified, the main
metric returned by `validation_step` will be used for comparison by default, the larger the
value, the better the effect. This method will affect the saving of the best model. If the
default behavior does not meet your requirements, please pass in a custom method.
Example:
.. code-block:: python
def compare_metrics(old_metric: dict, new_metric: dict):
mainkey = list(new_metric.keys())[0]
return old_metric[mainkey] < new_metric[mainkey]
'''
def __init__(self,
......@@ -130,7 +148,8 @@ class Trainer(object):
epochs(int) : Number of training loops, default is 1.
batch_size(int) : Batch size of per step, default is 1.
num_workers(int) : Number of subprocess to load data, default is 0.
eval_dataset(paddle.io.Dataset) : The validation dataset, deafult is None. If set, the Trainer will execute evaluate function every `save_interval` epochs.
eval_dataset(paddle.io.Dataset) : The validation dataset, default is None. If set, the Trainer will
execute evaluate function every `save_interval` epochs.
log_interval(int) : Log the train information every `log_interval` steps.
save_interval(int) : Save the checkpoint every `save_interval` epochs.
'''
......@@ -269,7 +288,14 @@ class Trainer(object):
return {'loss': avg_loss, 'metrics': avg_metrics}
return {'metrics': avg_metrics}
def training_step(self, batch: Any, batch_idx: int):
def training_step(self, batch: List[paddle.Tensor], batch_idx: int):
'''
One step for training, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]) : One batch of data
batch_idx(int) : The index of batch.
'''
if self.nranks > 1:
result = self.model._layers.training_step(batch, batch_idx)
else:
......@@ -296,17 +322,42 @@ class Trainer(object):
return loss, metrics
def validation_step(self, batch: Any, batch_idx: int):
'''
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]) : One batch of data
batch_idx(int) : The index of batch.
'''
if self.nranks > 1:
result = self.model._layers.validation_step(batch, batch_idx)
else:
result = self.model.validation_step(batch, batch_idx)
return result
def optimizer_step(self, current_epoch: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer,
def optimizer_step(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer,
loss: paddle.Tensor):
'''
One optimization step.
Args:
epoch_idx(int) : The index of epoch.
batch_idx(int) : The index of batch.
optimizer(paddle.optimizer.Optimizer) : Optimizer used.
loss(paddle.Tensor) : Loss tensor.
'''
self.optimizer.minimize(loss)
def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer):
def optimizer_zero_grad(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer):
'''
One step for clearing gradients.
Args:
epoch_idx(int) : The index of epoch.
batch_idx(int) : The index of batch.
optimizer(paddle.optimizer.Optimizer) : Optimizer used.
loss(paddle.Tensor) : Loss tensor.
'''
self.model.clear_gradients()
def _compare_metrics(self, old_metric: dict, new_metric: dict):
......
......@@ -38,8 +38,8 @@ class ImageClassifierModule(RunModule, ImageServing):
One step for training, which should be called as forward computation.
Args:
batch(list[paddle.Variable]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
batch(list[paddle.Tensor]) : The one batch data, which contains images and labels.
batch_idx(int) : The index of batch.
Returns:
results(dict) : The model outputs, such as loss and metrics.
......@@ -51,8 +51,8 @@ class ImageClassifierModule(RunModule, ImageServing):
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Variable]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
batch(list[paddle.Tensor]) : The one batch data, which contains images and labels.
batch_idx(int) : The index of batch.
Returns:
results(dict) : The model outputs, such as metrics.
......@@ -80,7 +80,7 @@ class ImageClassifierModule(RunModule, ImageServing):
images = self.transforms(images)
if len(images.shape) == 3:
images = images[np.newaxis, :]
preds = self(paddle.to_variable(images))
preds = self(paddle.to_tensor(images))
preds = F.softmax(preds, axis=1).numpy()
pred_idxs = np.argsort(preds)[::-1][:, :top_k]
res = []
......@@ -91,6 +91,3 @@ class ImageClassifierModule(RunModule, ImageServing):
res_dict[class_name] = preds[i][k]
res.append(res_dict)
return res
def is_better_score(self, old_score: dict, new_score: dict):
return old_score['acc'] < new_score['acc']
......@@ -29,7 +29,15 @@ from paddlehub.utils import log
class GitSource(object):
def __init__(self, url, path=None):
'''
Git source for PaddleHub module
Args:
url(str) : Url of git repository
path(str) : Path to store the git repository
'''
def __init__(self, url: str, path: str = None):
self.url = url
self._parse_result = urlparse(self.url)
......@@ -66,10 +74,25 @@ class GitSource(object):
log.logger.warning('An error occurred while loading {}'.format(self.path))
sys.path.remove(self.path)
def search_module(self, name, version=None):
def search_module(self, name: str, version: str = None) -> dict:
'''
Search PaddleHub module
Args:
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version)
def search_resouce(self, type, name, version=None):
def search_resouce(self, type: str, name: str, version: str = None) -> dict:
'''
Search PaddleHub Resource
Args:
type(str) : Resource type
name(str) : Resource name
version(str) : Resource version
'''
module = self.hub_modules.get(name, None)
if module and module.version.match(version):
return {
......@@ -82,7 +105,13 @@ class GitSource(object):
return None
@classmethod
def check(cls, url):
def check(cls, url: str) -> bool:
'''
Check if the specified url is a valid git repository link
Args:
url(str) : Url to check
'''
try:
git.cmd.Git().ls_remote(url)
return True
......
......@@ -21,10 +21,12 @@ PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub'
class HubServer(object):
'''PaddleHub server'''
def __init__(self):
self.sources = OrderedDict()
def _generate_source(self, url):
def _generate_source(self, url: str):
if ServerSource.check(url):
source = ServerSource(url)
elif GitSource.check(url):
......@@ -33,17 +35,34 @@ class HubServer(object):
raise RuntimeError()
return source
def add_source(self, url, key=None):
def add_source(self, url: str, key: str = None):
'''Add a module source(GitSource or ServerSource)'''
key = "source_{}".format(len(self.sources)) if not key else key
self.sources[key] = self._generate_source(url)
def remove_source(self, url=None, key=None):
def remove_source(self, url: str = None, key: str = None):
'''Remove a module source'''
self.sources.pop(key)
def search_module(self, name, version=None, source=None):
def search_module(self, name: str, version: str = None, source: str = None) -> dict:
'''
Search PaddleHub module
Args:
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version, source=source)
def search_resouce(self, type, name, version=None, source=None):
def search_resouce(self, type: str, name: str, version: str = None, source: str = None) -> dict:
'''
Search PaddleHub Resource
Args:
type(str) : Resource type
name(str) : Resource name
version(str) : Resource version
'''
sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources:
result = source.search_resouce(name=name, type=type, version=version)
......
......@@ -25,14 +25,37 @@ from paddlehub.utils import utils
class ServerSource(object):
def __init__(self, url, timeout=10):
'''
PaddleHub server source
Args:
url(str) : Url of the server
timeout(int) : Request timeout
'''
def __init__(self, url: str, timeout: int = 10):
self._url = url
self._timeout = timeout
def search_module(self, name, version=None):
def search_module(self, name: str, version: str = None) -> dict:
'''
Search PaddleHub module
Args:
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version)
def search_resouce(self, type, name, version=None):
def search_resouce(self, type: str, name: str, version: str = None) -> dict:
'''
Search PaddleHub Resource
Args:
type(str) : Resource type
name(str) : Resource name
version(str) : Resource version
'''
payload = {'environments': {}}
payload['word'] = name
......@@ -59,7 +82,13 @@ class ServerSource(object):
return None
@classmethod
def check(cls, url):
def check(cls, url: str) -> bool:
'''
Check if the specified url is a valid paddlehub server
Args:
url(str) : Url to check
'''
try:
r = requests.get(url + '/search')
return r.status_code == 200
......
......@@ -50,7 +50,7 @@ def redirect_estream(stream: IO):
@contextlib.contextmanager
def discard_oe():
'''
Redirect input and output stream to temporary file. In a sense,
Redirect the output and error streams to a temporary file. In a sense,
this is equivalent to discarding the output and error messages
'''
with generate_tempfile(mode='w') as _stream:
......
......@@ -96,15 +96,16 @@ class ProgressBar(object):
Examples:
.. code-block:: python
with ProgressBar('Download module') as bar:
for i in range(100):
bar.update(i / 100)
# with continuous bar.update, the progress bar in the terminal
# will continue to update until 100%
#
# Download module
# [##################################################] 100.00%
with ProgressBar('Download module') as bar:
for i in range(100):
bar.update(i / 100)
# with continuous bar.update, the progress bar in the terminal
# will continue to update until 100%
#
# Download module
# [##################################################] 100.00%
'''
def __init__(self, title: str, flush_interval: float = 0.1):
......@@ -126,6 +127,10 @@ class ProgressBar(object):
def update(self, progress: float):
'''
Update progress bar
Args:
progress: Processing progress, from 0.0 to 1.0
'''
msg = '[{:<50}] {:.2f}%'.format('#' * int(progress * 50), progress * 100)
need_flush = (time.time() - self.last_flush_time) >= self.flush_interval
......@@ -146,14 +151,14 @@ class FormattedText(object):
Args:
text(str) : Text content
width(int) : Text length, if the text is less than the specified length, it will be filled with spaces
align(str) : it must be:
======== ==================
align(str) : Text alignment, it must be:
======== ====================================
Charater Meaning
-------- ------------------
'<' left aligned
'^' middle aligned
'>' right aligned
======== ==================
-------- ------------------------------------
'<' The text will remain left aligned
'^' The text will remain middle aligned
'>' The text will remain right aligned
======== ====================================
color(str) : Text color, default is None(depends on terminal configuration)
'''
_MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE}
......@@ -293,12 +298,13 @@ class Table(object):
Table with adaptive width and height
Args:
colors(list[str]) : Text colors of contents one by one
aligns(list[str]) : Text aligns of contents one by one
widths(list[str]) : Text widths of contents one by one
colors(list[str]) : Text colors
aligns(list[str]) : Text alignments
widths(list[str]) : Text widths
Examples:
.. code-block:: python
table = Table(widths=[12, 20])
table.append('name', 'PaddleHub')
table.append('version', '2.0.0')
......@@ -337,9 +343,9 @@ class Table(object):
Args:
*contents(*list): Contents of the row, each content will be placed in a separate cell
colors(list[str]) : Text colors of contents one by one, if not set, the default value will be used.
aligns(list[str]) : Text aligns of contents one by one, if not set, the default value will be used.
widths(list[str]) : Text widths of contents one by one, if not set, the default value will be used.
colors(list[str]) : Text colors
aligns(list[str]) : Text alignments
widths(list[str]) : Text widths
'''
newrow = TableRow()
......
......@@ -32,7 +32,7 @@ import paddlehub.env as hubenv
class Version(packaging.version.Version):
'''Expand realization of packaging.version.Version'''
'''Extended implementation of packaging.version.Version'''
def match(self, condition: str) -> bool:
'''
......@@ -45,9 +45,9 @@ class Version(packaging.version.Version):
bool: True if the given version condition is met, else False
Examples:
from paddlehub.utils import Version
.. code-block:: python
Version('1.2.0').match('>=1.2.0a')
Version('1.2.0').match('>=1.2.0a')
'''
if not condition:
return True
......@@ -162,7 +162,6 @@ def download(url: str, path: str = None) -> str:
Examples:
.. code-block:: python
from paddlehub.utils.utils import download
url = 'https://xxxxx.xx/xx.tar.gz'
download(url, path='./output')
......@@ -182,7 +181,6 @@ def download_with_progress(url: str, path: str = None) -> Generator[str, int, in
Examples:
.. code-block:: python
from paddlehub.utils.utils import download_with_progress
url = 'https://xxxxx.xx/xx.tar.gz'
for filename, download_size, total_szie in download_with_progress(url, path='./output'):
......
......@@ -177,7 +177,6 @@ def archive(filename: str, recursive: bool = True, exclude: Callable = None, arc
Examples:
.. code-block:: python
from paddlehub.utils import archive
archive_path = '/PATH/TO/FILE'
archive(archive_path, arcname='output.tar.gz', arctype='tar.gz')
......@@ -200,7 +199,6 @@ def unarchive(name: str, path: str):
Examples:
.. code-block:: python
from paddlehub.utils import unarchive
unarchive_path = '/PATH/TO/FILE'
unarchive(unarchive_path, path='./output')
......@@ -219,7 +217,6 @@ def unarchive_with_progress(name: str, path: str) -> Generator[str, int, int]:
Examples:
.. code-block:: python
from paddlehub.utils.xarfile import unarchive_with_progress
unarchive_path = 'test.tar.gz'
for filename, extract_size, total_szie in unarchive_with_progress(unarchive_path, path='./output'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册