Pytorch——合并与分割

cat

torch.cat([a,b],dim=)
dim以外的纬度相同

stack

torch.stack([a,b],dim=)
a,b纬度必须完全一样,在dim前添加纬度合并

a=torch.rand(32,8)
b=torch.rand(32,8)
torch.stack([a,b].dim=0).shape
#形状是[2,32,8]

split

参数1
a1,a2,a3...an=a.split([m1,m2,m3...mn],dim=)    m1+m2+m3+...+mn=dim的size
参数2
a1,a2,a3...an=a.split(n,dim=)    n为均分的步长

chunk

c.chunk(num,dim=) # 按要拆分出的tensor的数量拆分