Global rank: 0, Weights array: [DTensor(local_tensor=tensor([
[101., 102., 103., 104., 105., 106., 107., 108., 109., 110.],
[111., 112., 113., 114., 115., 116., 117., 118., 119., 120.],
[121., 122., 123., 124., 125., 126., 127., 128., 129., 130.],
[131., 132., 133., 134., 135., 136., 137., 138., 139., 140.],
[141., 142., 143., 144., 145., 146., 147., 148., 149., 150.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=0),)), DTensor(local_tensor=tensor([
[101., 102., 103., 104., 105., 106., 107., 108., 109., 110.],
[111., 112., 113., 114., 115., 116., 117., 118., 119., 120.],
[121., 122., 123., 124., 125., 126., 127., 128., 129., 130.],
[131., 132., 133., 134., 135., 136., 137., 138., 139., 140.],
[141., 142., 143., 144., 145., 146., 147., 148., 149., 150.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=0),))]
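The dump above shows rank 0's local view of two weight DTensors: a 10x10 global weight on a 2-rank CPU mesh with `Shard(dim=0)` placement leaves each rank holding a 5x10 row slice (rank 0 gets rows with values 101-150). A minimal pure-Python sketch of that row-wise chunking (no torch required; `shard_dim0` is a hypothetical helper simulating `Shard(dim=0)` semantics, not a DTensor API):

```python
# Simulate DTensor's Shard(dim=0) placement on a 2-rank mesh: a 10x10
# global weight is split row-wise, so each rank holds a 5x10 local shard.
def shard_dim0(global_rows, world_size):
    """Split rows evenly across ranks (Shard(dim=0) semantics, even case)."""
    per_rank = len(global_rows) // world_size  # assumes even divisibility
    return [global_rows[r * per_rank:(r + 1) * per_rank]
            for r in range(world_size)]

# Build the 10x10 global weight with values 101..200, matching the log above.
global_weight = [[101 + 10 * i + j for j in range(10)] for i in range(10)]
shards = shard_dim0(global_weight, world_size=2)

print(len(shards[0]), len(shards[0][0]))  # each local shard is 5 x 10
print(shards[0][0][0], shards[1][0][0])   # rank 0 starts at 101, rank 1 at 151
```

Rank 0's shard runs 101-150, exactly the local tensor printed in the log; rank 1 would print 151-200 for the same DTensor.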
# Input, output for net1
Input=== from: (Replicate(),) to: (Replicate(),)
Output=== from: (Shard(dim=1),) to: (Shard(dim=-1),)
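net1's output moves from `Shard(dim=1)` to `Shard(dim=-1)`, which for a 2-D tensor is the same placement: negative dims are normalized against the tensor's rank, so no communication is needed. A small sketch of that normalization (`normalize_dim` is a hypothetical helper illustrating the rule, not a DTensor API):

```python
# Shard(dim=-1) and Shard(dim=1) name the same placement on a 2-D tensor:
# negative dims are normalized against the tensor's rank (here ndim == 2).
def normalize_dim(dim, ndim):
    """Normalize a possibly-negative dim index, as placements resolve it."""
    return dim % ndim

print(normalize_dim(-1, 2))  # 1 -> Shard(dim=-1) == Shard(dim=1) for 2-D
```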
# Input, output for net2
Input=== from: (Shard(dim=-1),) to: (Replicate(),)
[rank0]:W1024 06:58:20.513000 46578 torch/distributed/tensor/_redistribute.py:202] redistribute from S(1) to R on mesh dim 0
Output=== from: (Shard(dim=1),) to: (Shard(dim=-1),)
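The warning above is the interesting step: net2's input arrives sharded on the last dim (`S(1)`) but is required to be `Replicate()`, so DTensor performs a redistribute, which at runtime is an all-gather along the sharded dimension. A pure-Python sketch of that semantics (no torch; `allgather_dim1` and the shard values are hypothetical, for illustration only):

```python
# Simulate redistribute from Shard(dim=1) to Replicate() on a 2-rank mesh:
# each rank holds a column shard, and redistribution all-gathers the shards
# and concatenates them along dim 1, so every rank gets the full tensor.
def allgather_dim1(local_shards):
    """Concatenate per-rank column shards back into the replicated tensor."""
    rows = len(local_shards[0])
    return [[v for shard in local_shards for v in shard[i]]
            for i in range(rows)]

# Two 2x2 column shards of a 2x4 tensor (made-up values for illustration).
rank0 = [[1, 2], [5, 6]]
rank1 = [[3, 4], [7, 8]]
replicated = allgather_dim1([rank0, rank1])
print(replicated)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

After this gather every rank computes on the full input, which is why the log shows the redistribute happening on mesh dim 0 before net2 runs.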