t-SNE 可视化 PyTorch 网络：训练前代码和训练后代码
时间: 2024-06-11 09:04:09 浏览: 204
PyTorch 训练代码
训练前代码:
```python
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# Define the model.
class Net(torch.nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 10 with ReLU activations.

    The attribute names fc1/fc2/fc3 determine the state_dict keys, so they must
    stay the same in every script that saves or loads this model's parameters.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that parameter
        # order (used by the optimizer) and seeded initialisation are unchanged.
        self.fc1 = torch.nn.Linear(784, 512)
        self.fc2 = torch.nn.Linear(512, 256)
        self.fc3 = torch.nn.Linear(256, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Flatten *x* into rows of 784 values and return the raw fc3 outputs.

        The result has shape (rows, 10) and is not passed through softmax.
        """
        # reshape (not view) also accepts non-contiguous inputs such as
        # transposed image tensors, where view(-1, 784) raises a RuntimeError.
        x = x.reshape(-1, 784)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
# Load the MNIST training set. torchvision was used here but never imported,
# which made the original script fail with a NameError.
import torchvision

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # (0.1307,), (0.3081,) are the commonly used MNIST mean/std values.
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=64,
    shuffle=True,
)
# Create the model and the optimizer.
import inspect

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# t-SNE over the whole MNIST training set (~60k images) with 5000 iterations
# is very slow; a subset of this size is enough to show the class structure.
# Set to None to embed every sample.
MAX_TSNE_SAMPLES = 5000


def collect_outputs(model, loader, max_samples=MAX_TSNE_SAMPLES):
    """Run *model* over *loader* and return (outputs, labels) as NumPy arrays.

    Iteration stops once at least *max_samples* rows have been collected and
    the arrays are truncated to that length; None means use the whole loader.
    train_loader shuffles, so the kept rows are a random subset.
    """
    model.eval()  # inference mode (a no-op for this Net, correct in general)
    outputs, labels = [], []
    seen = 0
    with torch.no_grad():  # no autograd graph is needed just to read outputs
        for data, target in loader:
            outputs.append(model(data).cpu().numpy())
            labels.append(target.numpy())
            seen += len(target)
            if max_samples is not None and seen >= max_samples:
                break
    outputs = np.concatenate(outputs, axis=0)[:max_samples]
    labels = np.concatenate(labels, axis=0)[:max_samples]
    return outputs, labels


def plot_tsne(outputs, labels, title):
    """Embed *outputs* in 2-D with t-SNE and scatter-plot them coloured by *labels*."""
    # scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` and 1.7 removed the
    # old name; pass the iteration count under whichever name is supported.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    tsne = TSNE(n_components=2, perplexity=30, init='pca', **{iter_kw: 5000})
    outputs_tsne = tsne.fit_transform(outputs)
    points = plt.scatter(outputs_tsne[:, 0], outputs_tsne[:, 1], c=labels, cmap='tab10', s=5)
    plt.colorbar(points, label='digit')
    plt.title(title)
    plt.show()


# 1) Visualise the untrained network first: this is the "before training" plot.
#    (The original trained for 10 epochs before plotting, so it never showed
#    the untrained network.)
outputs, labels = collect_outputs(model, train_loader)
plot_tsne(outputs, labels, 'Before training')

# 2) Train the model.
model.train()  # collect_outputs switched the model to eval mode
for epoch in range(10):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

# 3) Save the trained parameters; the post-training script loads 'model.pth'.
torch.save(model.state_dict(), 'model.pth')
```
训练后代码（加载已保存的模型参数 model.pth 后再做 t-SNE 可视化）:
```python
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# Define the model.
class Net(torch.nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 10 with ReLU activations.

    The attribute names fc1/fc2/fc3 determine the state_dict keys, so they must
    stay the same in every script that saves or loads this model's parameters.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that parameter
        # order (used by the optimizer) and seeded initialisation are unchanged.
        self.fc1 = torch.nn.Linear(784, 512)
        self.fc2 = torch.nn.Linear(512, 256)
        self.fc3 = torch.nn.Linear(256, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Flatten *x* into rows of 784 values and return the raw fc3 outputs.

        The result has shape (rows, 10) and is not passed through softmax.
        """
        # reshape (not view) also accepts non-contiguous inputs such as
        # transposed image tensors, where view(-1, 784) raises a RuntimeError.
        x = x.reshape(-1, 784)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
# Load the MNIST training set. torchvision was used here but never imported,
# which made the original script fail with a NameError.
import torchvision

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # (0.1307,), (0.3081,) are the commonly used MNIST mean/std values.
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=64,
    shuffle=True,
)
# Create the model. This script only runs inference, so the unused SGD
# optimizer from the original has been dropped.
model = Net()

# Load trained parameters; the file must hold a state_dict saved with
# torch.save(model.state_dict(), 'model.pth').
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model.eval()  # inference mode (a no-op for this Net, correct in general)
# Collect the network outputs and visualise them with t-SNE.
import inspect

# t-SNE over the whole MNIST training set (~60k images) with 5000 iterations
# is very slow; a subset of this size is enough to show the class structure.
# Set to None to embed every sample.
MAX_TSNE_SAMPLES = 5000


def collect_outputs(model, loader, max_samples=MAX_TSNE_SAMPLES):
    """Run *model* over *loader* and return (outputs, labels) as NumPy arrays.

    Iteration stops once at least *max_samples* rows have been collected and
    the arrays are truncated to that length; None means use the whole loader.
    train_loader shuffles, so the kept rows are a random subset.
    """
    model.eval()  # inference mode (a no-op for this Net, correct in general)
    outputs, labels = [], []
    seen = 0
    with torch.no_grad():  # no autograd graph is needed just to read outputs
        for data, target in loader:
            outputs.append(model(data).cpu().numpy())
            labels.append(target.numpy())
            seen += len(target)
            if max_samples is not None and seen >= max_samples:
                break
    outputs = np.concatenate(outputs, axis=0)[:max_samples]
    labels = np.concatenate(labels, axis=0)[:max_samples]
    return outputs, labels


def plot_tsne(outputs, labels, title):
    """Embed *outputs* in 2-D with t-SNE and scatter-plot them coloured by *labels*."""
    # scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` and 1.7 removed the
    # old name; pass the iteration count under whichever name is supported.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    tsne = TSNE(n_components=2, perplexity=30, init='pca', **{iter_kw: 5000})
    outputs_tsne = tsne.fit_transform(outputs)
    points = plt.scatter(outputs_tsne[:, 0], outputs_tsne[:, 1], c=labels, cmap='tab10', s=5)
    plt.colorbar(points, label='digit')
    plt.title(title)
    plt.show()


outputs, labels = collect_outputs(model, train_loader)
plot_tsne(outputs, labels, 'After training')
```
阅读全文