diff --git a/test/tree/integration/test_torch.py b/test/tree/integration/test_torch.py
index 19871568657cd37625d9c83f2ea1cabfea0e6790..a9a0a8e1cb21ec47f0726fa99cd3c73becb6b39b 100644
--- a/test/tree/integration/test_torch.py
+++ b/test/tree/integration/test_torch.py
@@ -1,3 +1,4 @@
+from typing import Tuple, Mapping
 from unittest import skipUnless
 
 import pytest
@@ -94,3 +95,87 @@ class TestTreeIntegrationTorch:
 
         assert _t_isclose(foox(x, y), x + y).all() == \
                FastTreeValue({'a': torch.tensor(True), 'b': torch.tensor(True)})
+
+    # NOTE(review): the ``vpython < '3.11'`` guard presumably exists because
+    # ``torch.compile`` was not yet usable on Python 3.11 with torch 2.0 --
+    # confirm before relaxing it.
+    @skipUnless(vpip('torch') >= '2.0.0' and OS.linux and vpython < '3.11',
+                'Torch 2 on linux platform with python < 3.11 required')
+    def test_with_module(self):
+        """
+        Check that a ``torch.compile``-d module whose heads are kept in a
+        ``FastTreeValue`` accepts both a plain tensor and a tree of tensors.
+        """
+        from torch import nn
+
+        class MLP(nn.Module):
+            """Chain of bias-enabled linear layers, with no activations."""
+
+            def __init__(self, in_features: int, out_features: int, layers: Tuple[int, ...] = (1024,)):
+                super().__init__()
+                self.in_features = in_features
+                self.out_features = out_features
+                self.layers = layers
+                # Widths from input to output; adjacent pairs size each layer.
+                ios = [self.in_features, *self.layers, self.out_features]
+                self.mlp = nn.Sequential(
+                    *(
+                        nn.Linear(in_, out_, bias=True)
+                        for in_, out_ in zip(ios[:-1], ios[1:])
+                    )
+                )
+
+            def forward(self, x):
+                return self.mlp(x)
+
+        class MultiHeadMLP(nn.Module):
+            """One independent ``MLP`` head per entry of ``out_features``."""
+
+            def __init__(self, in_features: int, out_features: Mapping[str, int],
+                         layers: Tuple[int, ...] = (1024,)):
+                super().__init__()
+                self.in_features = in_features
+                self.out_features = out_features
+                self.layers = layers
+                _networks = {
+                    o_name: MLP(in_features, o_feat, layers)
+                    for o_name, o_feat in self.out_features.items()
+                }
+                # The same head instances are stored twice: in a ModuleDict
+                # and in the FastTreeValue that ``forward`` calls.
+                self.mlps = nn.ModuleDict(_networks)
+                self._t_mlps = FastTreeValue(_networks)
+
+            def forward(self, x):
+                return self._t_mlps(x)
+
+        net = MultiHeadMLP(
+            20,
+            {'a': 10, 'b': 20, 'c': 14, 'd': 3},
+        )
+        net = torch.compile(net)
+
+        # Plain tensor input: the result is a tree with one output per head.
+        input1 = torch.randn(3, 20)
+        output1 = net(input1)
+        assert output1.shape == FastTreeValue({
+            'a': torch.Size([3, 10]),
+            'b': torch.Size([3, 20]),
+            'c': torch.Size([3, 14]),
+            'd': torch.Size([3, 3]),
+        })
+
+        # Tree input: each head gets its own tensor with its own leading dims.
+        input2 = FastTreeValue.func()(torch.randn)(FastTreeValue({
+            'a': (3, 20),
+            'b': (4, 20),
+            'c': (20,),
+            'd': (2, 5, 20),
+        }))
+        output2 = net(input2)
+        assert output2.shape == FastTreeValue({
+            'a': torch.Size([3, 10]),
+            'b': torch.Size([4, 20]),
+            'c': torch.Size([14]),
+            'd': torch.Size([2, 5, 3]),
+        })