def train(device, model, opt, loss_fn, train_loader): model.train() epoch_loss = 0 f1 = [] for g in train_loader: g = g.to(device) feat = g.ndata['feat'] label = g.ndata['label'] logits = model(g, feat) loss = loss_fn(logits, label) f1.append(get_f1(logits.detach().cpu().numpy(), label.detach().cpu().numpy())) epoch_loss += loss.data.item() opt.zero_grad() loss.backward() opt.step() return epoch_loss / len(train_loader), np.mean(f1),写一个train_loader以调用该函数
时间: 2023-12-10 22:03:16 浏览: 129
graph_opt.pb
这里提供一个简单的train_loader示例:
```
from torch.utils.data import Dataset, DataLoader
class GraphDataset(Dataset):
def __init__(self, graphs, targets):
self.graphs = graphs
self.targets = targets
def __len__(self):
return len(self.graphs)
def __getitem__(self, idx):
return self.graphs[idx], self.targets[idx]
train_dataset = GraphDataset(train_graphs, train_targets)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
```
其中,`train_graphs` 是一个包含多个图的列表,而 `train_targets` 是相应的目标标签。`GraphDataset` 类将每个图和其对应的标签打包成一个数据样本,并且通过 `DataLoader` 对象以批量的方式进行训练。在训练时,可以使用 `train()` 函数进行模型训练和评估。
阅读全文