提交 f83db2f4 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): fix bug of jax integration

上级 421d9363
......@@ -7,6 +7,7 @@ from treevalue import FastTreeValue, register_for_jax
try:
import jax
from jax.tree_util import register_pytree_node
except (ModuleNotFoundError, ImportError):
jax = None
......
......@@ -6,6 +6,7 @@ from treevalue import FastTreeValue, register_for_torch
try:
import torch
from torch.utils._pytree import _register_pytree_node
except (ImportError, ModuleNotFoundError):
torch = None
......
......@@ -3,13 +3,16 @@ from functools import wraps
try:
import jax
from jax.tree_util import register_pytree_node
except (ModuleNotFoundError, ImportError):
from .cjax import register_for_jax as _original_register_for_jax
@wraps(_original_register_for_jax)
def register_for_jax(cls):
warnings.warn(f'Jax is not installed, registration of {cls!r} will be ignored.')
warnings.warn(f'Jax doesn\'t have tree_util module due to either not installed '
f'or the installed version is too low, '
f'so the registration of {cls!r} will be ignored.')
else:
from .cjax import register_for_jax
from ..tree import TreeValue
......
......@@ -3,13 +3,16 @@ from functools import wraps
try:
import torch
from torch.utils._pytree import _register_pytree_node
except (ModuleNotFoundError, ImportError):
from .ctorch import register_for_torch as _original_register_for_torch
@wraps(_original_register_for_torch)
def register_for_torch(cls):
warnings.warn(f'Torch is not installed, registration of {cls!r} will be ignored.')
warnings.warn(f'Pytree module is not included in the Torch installation '
f'or the installed version is too low, '
f'so the registration of {cls!r} will be ignored.')
else:
from .ctorch import register_for_torch
from ..tree import TreeValue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册