Please write t-SNE code for a ResNet18 network whose input is a dataset of S-transform time-frequency images built from the CWRU data.
Time: 2024-03-18 07:41:46
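Below is t-SNE code that feeds S-transform time-frequency images of the CWRU data into a ResNet18 network and visualizes the extracted features, for your reference. It assumes the images have already been generated and sorted into one sub-folder per class under `./cwru_spectrogram/train`, and that a 10-class ResNet18 fine-tuned on them is saved as `resnet18_cwru_spectrogram.pth`: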
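```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import normalize
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms, datasets

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the CWRU time-frequency images (one sub-folder per fault class)
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # force a square 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = datasets.ImageFolder(
    root='./cwru_spectrogram/train',
    transform=data_transform
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=False  # keep order aligned with train_dataset.targets
)

# Load the ResNet18 fine-tuned on the CWRU images (10 classes)
model = models.resnet18()  # architecture only; weights come from the checkpoint below
model.fc = nn.Linear(512, 10)
model.load_state_dict(torch.load('resnet18_cwru_spectrogram.pth', map_location=device))
# Replace the classifier so the model outputs the 512-d penultimate features
model.fc = nn.Identity()
model = model.to(device).eval()  # eval() makes BatchNorm use its running statistics

# Get feature vectors from ResNet18
feature_vectors = []
with torch.no_grad():
    for inputs, _ in train_loader:
        inputs = inputs.to(device)  # model and inputs must be on the same device
        feature_vectors.append(model(inputs).cpu().numpy())
feature_vectors = np.vstack(feature_vectors)

# L2-normalize each feature vector
feature_vectors = normalize(feature_vectors, norm='l2')

# Perform t-SNE
tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, init='pca', random_state=0)
tsne_results = tsne.fit_transform(feature_vectors)

# Plot t-SNE results, one color per class
labels = np.array(train_dataset.targets)
for idx, name in enumerate(train_dataset.classes):
    mask = labels == idx
    plt.scatter(tsne_results[mask, 0], tsne_results[mask, 1], s=5, label=name)
plt.legend()
plt.title('t-SNE of ResNet18 features')
plt.show()
```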
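The code above assumes the S-transform images already exist on disk. A minimal NumPy sketch of that step, assuming the standard FFT-based discrete S-transform with the usual Gaussian window and no extra width parameter: it cuts a 1-D vibration signal into non-overlapping segments and writes |S| of each segment as a PNG in the one-folder-per-class layout that `ImageFolder` expects. The helper names `stockwell_transform` and `save_st_images`, the segment length of 1024, and the `jet` colormap are hypothetical choices, not part of the original answer.

```python
import os
import numpy as np
import matplotlib.pyplot as plt


def stockwell_transform(x):
    """Discrete S-transform of a 1-D signal, computed via the FFT.

    Returns a complex array of shape (N//2 + 1, N): rows are frequency
    bins 0..N/2, columns are time samples.
    """
    x = np.asarray(x, dtype=float)
    n_samples = x.size
    spectrum = np.fft.fft(x)
    # Signed frequency index of each FFT bin: 0, 1, ..., -2, -1
    m = np.fft.fftfreq(n_samples) * n_samples
    S = np.empty((n_samples // 2 + 1, n_samples), dtype=complex)
    S[0, :] = x.mean()  # the zero-frequency voice is the signal mean
    for k in range(1, n_samples // 2 + 1):
        # Gaussian in the frequency domain; its width scales with k
        gaussian = np.exp(-2 * np.pi**2 * m**2 / k**2)
        # Shift the spectrum so bin k sits at index 0, window it, invert
        S[k, :] = np.fft.ifft(np.roll(spectrum, -k) * gaussian)
    return S


def save_st_images(signal, out_dir, class_name, seg_len=1024):
    """Split a signal into segments and save |S| of each as a PNG.

    Files are written to out_dir/class_name/<i>.png (ImageFolder layout).
    Returns the number of images written.
    """
    class_dir = os.path.join(out_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)
    n_segments = len(signal) // seg_len  # a trailing partial segment is dropped
    for i in range(n_segments):
        segment = signal[i * seg_len:(i + 1) * seg_len]
        magnitude = np.abs(stockwell_transform(segment))
        # origin='lower' puts low frequencies at the bottom of the image
        plt.imsave(os.path.join(class_dir, f'{i}.png'), magnitude,
                   cmap='jet', origin='lower')
    return n_segments
```

For example, `save_st_images(signal, './cwru_spectrogram/train', 'inner_race')` would fill one class folder. Each image is `(seg_len // 2 + 1) x seg_len` pixels before the `Resize((224, 224))` in the loader; a shorter segment gives a smaller, faster transform at the cost of frequency resolution.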
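Note that this is only example code; adjust the paths, the number of classes, and t-SNE hyperparameters such as `perplexity` to your own data and needs.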
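The raw CWRU recordings are distributed as MATLAB `.mat` files. A minimal loader sketch, assuming each accelerometer channel is stored under a variable whose name ends in `_DE_time` (drive end) or `_FE_time` (fan end), e.g. `X097_DE_time`; `load_cwru_signal` is a hypothetical helper name, and if several variables match it simply takes the first in sorted order.

```python
import numpy as np
from scipy.io import loadmat


def load_cwru_signal(mat_path, channel='DE'):
    """Return one accelerometer channel from a CWRU .mat file as a 1-D array.

    Assumes the variable name ends in f'_{channel}_time',
    e.g. 'X097_DE_time' for the drive-end sensor.
    """
    data = loadmat(mat_path)  # dict of variable name -> array (plus '__header__' etc.)
    suffix = f'_{channel}_time'
    keys = sorted(k for k in data if k.endswith(suffix))
    if not keys:
        raise KeyError(f'no variable ending in {suffix!r} in {mat_path}')
    # MATLAB stores the signal as an (n, 1) column; flatten it to 1-D
    return np.ravel(data[keys[0]]).astype(float)
```

The returned array can be passed straight to a segmentation and S-transform step such as the one sketched earlier.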