pytorch笔记

pytorch刚上手确实不太容易适应。特别是Andrew给出的1.x的tensorflow代码,和当前torch的差异还是很大的。这里的用法挺琐碎的,用作备忘性质。

one_hot

很多torch的函数都在torch.nn.functional里面。一般地,我们在import的时候都会import这个as F。

1
2
3
4
5
6
7
8
9
10
11
one_hot_matrix = F.one_hot(torch.tensor(labels, dtype=torch.int64), num_classes=C)
one_hot_matrix = one_hot_matrix.T
# 这里dtype必须是int64,不然会报错。labels就是C个种类的具体类别,从0到C-1;而且,如果是想要按照熟悉的样本在一列上,最后还得转置。

# 例如,
labels = torch.tensor([1,2,3,0,2,1])
C = 4
# 得到的矩阵就是tensor([[0, 0, 0, 1, 0, 0],
# [1, 0, 0, 0, 0, 1],
# [0, 1, 0, 0, 1, 0],
# [0, 0, 1, 0, 0, 0]])

当height * width * channel转化成channel * height * width

调用X = np.transpose(X, [0, 3, 1, 2])

我的第一个pytorch多分类网络

苦于没有现成的教程,这个torch的代码编写十分费劲。我尽可能地在这里多写注释,让每一步都知道是怎么来的,便于复现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn as nn
import torch.optim as optim # 以上都是必要的库,optim用来声明优化函数
class NeuralNetwork(nn.Module): # 定义我们的网络类,其中父类必须这么写
def __init__(self):
super(NeuralNetwork, self).__init__() # 调用父类init传自身类名,必须写
self.linear_relu_stack= nn.Sequential( # sequential声明自己的网络的结构
nn.Linear(12288, 25),
nn.ReLU(),
nn.Linear(25, 12),
nn.ReLU(),
nn.Linear(12, 6),
)

def forward(self, x): # forward必须写,里面必须return经过网络的结果,这里变量起名为logits
logits = self.linear_relu_stack(x)
return(logits)

# model = NeuralNetwork().to('cpu') # 如果是有显卡的写成cuda,这句话不写也行

class MyDataSet(torch.utils.data.Dataset): # 如果要用自己的dataset,就必须声明自己的数据集类,这三个函数必须要有
def __init__(self, x, y):
super(MyDataSet, self).__init__() # 调用父类init传自身类名,必须写
self.x = torch.tensor(x, dtype=torch.float32).T # 这里是在Andrew作业中比较特别。Andrew习惯将样本特征竖起来,torch习惯横着
self.y = y.T # 同理,而且上面的torch.float32貌似不写会出问题;这里的x为(64*64*3, sampleNum),且已经flatten过(是特例,可改

def __getitem__(self, index): # 当显式地调用对象名[index]的时候必须返回tuple,而且必须是(样本,标签)的顺序
return (self.x[index], self.y[index])

def __len__(self): # 返回dataset中样本数目
return self.x.shape[0]

def train(trainData, trainLabel, savedModel=None): # 正式的训练过程,我把它封装成函数,而且想继续训练的话就传入savedModel
myTrainLoader = torch.utils.data.DataLoader(MyDataSet(trainData, trainLabel), shuffle=True, batch_size=64) # 分别把图片形式的数据做成数据集传入loader,指定batch_size
net = None # 初始化net
if savedModel == None: # 如果不传入训练过的模型就新建一个
net = NeuralNetwork()
else: # 直接使用训练过的模型
net = savedModel
optimizer = optim.Adam(net.parameters(), lr=0.00001) # 声明梯度下降的方式为adam
criterion = nn.CrossEntropyLoss() # 定义交叉熵函数作为损失函数
for epoch in range(150): # 跑150个epoch(一个epoch就是对样本集所有样本的遍历)
runningLoss = 0.0 # 初始化loss
for i, data in enumerate(myTrainLoader, 0): # 枚举loader,写法固定为index,data
inputs, labels = data # data中就是我们刚才定义的__getitem__的顺序
optimizer.zero_grad() # 初始化梯度,必须要有
outputs = net(inputs) # 把data中的样本放入net而不放入标签,得到outputs输出
loss = criterion(nn.Softmax(dim=1)(outputs), labels) # 根据outputs和原有的标签计算交叉熵
loss.backward() # 反向传播计算更新参数,必须要有
optimizer.step() # 更新参数
runningLoss += float(loss.data) # 把一个epoch中的loss更新

print(f'epoch{epoch}:', runningLoss) # 每一个epoch都打印一次loss
lossData.append(runningLoss) # lossData是个我定义的数组,用来把loss的图画出来
runningLoss = 0.0 # loss归0,然后loss在下一个epoch结束的时候再打印
torch.save(net, 'net.pkl') # 保存网络信息,不过这么存可能会有问题,可以换用存dict的方式
print('finish!')

# 训练模型的过程,隐去了X_train和Y_train得到的过程
train(X_train, Y_train) # 如果加上第三个参数,就是对模型的继续训练

# 下面是对网络的调用,是在验证网络
trainedNet = torch.load('./net.pkl') # 加载模型
myTrainLoader = torch.utils.data.DataLoader(MyDataSet(X_train, Y_train)) # 把训练集加载进来
trueNum = 0 # 正确分类的个数
for i, data in enumerate(myTrainLoader, 0): # 还是按照train里面的写法
inputs, labels = data
outputs = int(trainedNet(inputs).data.argmax()) # 获得模型的结果,和label比对
if outputs == int(labels.argmax()):
trueNum += 1
print('train accuracy:', trueNum / 1080) # 这里1080是样本数,可以改成想要的样本个数
myTrainLoader = torch.utils.data.DataLoader(MyDataSet(X_test, Y_test)) # 同上,只不过改成了测试集
trueNum = 0
for i, data in enumerate(myTrainLoader, 0):
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
outputs = int(trainedNet(inputs).data.argmax())
if outputs == int(labels.argmax()):
trueNum += 1
print('test accuracy', trueNum / 120)