训练模型的 main graph,如何查看?
时间: 2023-06-04 08:07:31 浏览: 94
您可以使用TensorBoard来查看训练模型的main graph。TensorBoard是一个可视化工具,可以帮助您直观地了解模型训练的各个方面。在TensorBoard中,您可以查看模型的图(包括main graph)、张量的值和梯度等信息。如果您已经在代码中使用了TensorFlow,那么您可以通过以下命令启动TensorBoard,并在浏览器中查看和分析您的模型:
```
tensorboard --logdir=path/to/log-directory
```
其中,path/to/log-directory是您保存训练日志的目录。在TensorBoard中,您可以在"Graphs"选项卡下查看main graph。
相关问题
猫狗训练模型代码实例
这篇博客使用TensorFlow框架,使用预训练模型进行猫狗分类。代码在Github上已经公开,可以从[这里](https://github.com/jmhIcoding/dogsVScats)获取。
使用预训练模型进行微调的代码如下,其中包括了数据处理、模型构建和训练三个部分:
数据处理部分[^1]:
```python
import tensorflow as tf
slim = tf.contrib.slim
...
def get_dataset(dataset_name, split_name, dataset_dir, file_pattern):
    """Return the requested split of a named dataset.

    Args:
        dataset_name: one of 'imagenet', 'flowers', 'cifar10', 'mnist',
            'cats_vs_dogs'.
        split_name: which split to load (e.g. 'train').
        dataset_dir: directory holding the dataset files.
        file_pattern: pattern containing a '%s' placeholder for the split name.

    Raises:
        ValueError: if dataset_name is not one of the supported names.
    """
    file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
    # Dispatch table instead of an if/elif chain; each value is the
    # dataset-specific get_split callable.
    loaders = {
        'imagenet': dataset.get_split,
        'flowers': flowers.get_split,
        'cifar10': cifar10.get_split,
        'mnist': mnist.get_split,
        'cats_vs_dogs': dogs.get_split,
    }
    if dataset_name not in loaders:
        raise ValueError('Invalid dataset name %s.' % dataset_name)
    return loaders[dataset_name](split_name, dataset_dir, file_pattern)
def load_batch(dataset, batch_size, height, width, is_training=True):
    """Load one batch of preprocessed images and labels from *dataset*.

    Shuffling is enabled only while training; batching runs on a
    four-thread input queue.
    """
    # Queue-backed provider that deserializes individual examples.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=is_training,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size)
    raw_image, label = provider.get(['image', 'label'])
    # Inception-style preprocessing (resize/crop, train-time augmentation).
    image = inception_preprocessing.preprocess_image(
        raw_image, height, width, is_training=is_training)
    # Assemble fixed-size batches from the preprocessed stream.
    images, labels = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=4,
        capacity=5 * batch_size)
    return images, labels
```
模型构建部分:
```python
import tensorflow as tf
slim = tf.contrib.slim
...
def build_model(inputs, num_classes, is_training=True, scope='vgg_16'):
    """
    Build a VGG-16 network with slim layers.

    Args:
        inputs: input image batch tensor.
        num_classes: number of outputs of the final 'fc8' layer.
        is_training: enables dropout in fc6/fc7 when True.
        scope: variable scope name for the model.

    Returns:
        net: logits tensor produced by the final 1x1 conv ('fc8').
        end_points: dict mapping layer names to their output tensors.
    """
    with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
        end_points_collection = sc.name + '_end_points'
        # Record every layer's output in a collection so end_points can be
        # assembled after the network is built.
        with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                            outputs_collections=end_points_collection):
            net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
            net = slim.max_pool2d(net, [2, 2], scope='pool1')
            # NOTE: conv2-conv4 stages elided in the original excerpt.
            ...
            net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5')
            net = slim.max_pool2d(net, [2, 2], scope='pool5')
            # Fully-connected layers expressed as convolutions (VGG style).
            net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
            net = slim.dropout(net, 0.5, is_training=is_training,
                               scope='dropout6')
            net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
            net = slim.dropout(net, 0.5, is_training=is_training,
                               scope='dropout7')
            # Final classifier: raw logits, no activation/normalization.
            net = slim.conv2d(net, num_classes, [1, 1],
                              activation_fn=None,
                              normalizer_fn=None,
                              scope='fc8')
        end_points = slim.utils.convert_collection_to_dict(end_points_collection)
        return net, end_points
```
训练部分:
```python
import tensorflow as tf
slim = tf.contrib.slim
...
def run_training(dataset_name, train_dir, dataset_dir, num_classes=2, batch_size=32, num_epochs=10, initial_learning_rate=0.0001):
    """
    Fine-tune the model on the given dataset with slim's training loop.

    Args:
        dataset_name: dataset key understood by get_dataset (e.g. 'cats_vs_dogs').
        train_dir: directory where checkpoints and summaries are written.
        dataset_dir: directory holding the TFRecord files.
        num_classes: number of output classes (default 2: cats vs dogs).
        batch_size: images per training batch.
        num_epochs: passes over the dataset; combined with the dataset size
            to derive the total step count passed to slim.learning.train.
        initial_learning_rate: starting rate for the exponential-decay schedule.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)
        # Fetch the 'train' split; '%s_*.tfrecord' is filled with the split name.
        dataset = get_dataset(dataset_name, 'train', dataset_dir, '%s_*.tfrecord')
        images, labels = load_batch(dataset, batch_size=batch_size, height=224, width=224, is_training=True)
        # Build the network (VGG-16 logits).
        logits, end_points = build_model(images, num_classes=num_classes, is_training=True)
        # Loss: softmax cross-entropy on one-hot labels; get_total_loss also
        # folds in any regularization losses slim collected.
        one_hot_labels = slim.one_hot_encoding(labels, num_classes)
        slim.losses.softmax_cross_entropy(logits=logits, onehot_labels=one_hot_labels)
        total_loss = slim.losses.get_total_loss()
        # Adam with exponential learning-rate decay (x0.96 every 1000 steps).
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(
            initial_learning_rate,
            global_step,
            decay_steps=1000,
            decay_rate=0.96,
            staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = slim.learning.create_train_op(total_loss, optimizer, global_step=global_step)
        # Run the managed training loop: log every step, summarize every 20s,
        # checkpoint every 120s.
        saver = tf.train.Saver(tf.global_variables())
        slim.learning.train(
            train_op,
            train_dir,
            log_every_n_steps=1,
            save_summaries_secs=20,
            saver=saver,
            number_of_steps=num_epochs * dataset.num_samples // batch_size,
            save_interval_secs=120)
if __name__ == '__main__':
    run_training('cats_vs_dogs', '/tmp/cats_vs_dogs', '/path/to/dataset')
```
多个graph训练VGAE代码示例
下面是一个多个graph训练VGAE的代码示例:
```python
import dgl
import torch
import torch.nn.functional as F
from dgl.nn import GraphConv
from dgl.data import MiniGCDataset
class VGAE(torch.nn.Module):
    """Variational Graph Auto-Encoder.

    Two GraphConv layers encode node features; two linear heads then
    produce the per-node mean and log-std of the latent Gaussian.
    """

    def __init__(self, in_feats, hidden_size):
        super(VGAE, self).__init__()
        # Shared GCN encoder.
        self.conv1 = GraphConv(in_feats, hidden_size)
        self.conv2 = GraphConv(hidden_size, hidden_size)
        # Gaussian parameter heads.
        self.mean_fc = torch.nn.Linear(hidden_size, hidden_size)
        self.logstd_fc = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, g):
        """Encode graph *g*; returns (mean, logstd), one row per node."""
        feats = g.ndata['feat']
        hidden = F.relu(self.conv1(g, feats))
        hidden = F.relu(self.conv2(g, hidden))
        return self.mean_fc(hidden), self.logstd_fc(hidden)
def train(dataset):
    """Train a VGAE over an iterable of (graph, label) pairs.

    The loss is the negative reconstruction log-likelihood of each graph's
    adjacency matrix plus the KL divergence of the approximate posterior
    from a standard normal.

    Fixes vs. the original: the unused local ``h = graph.ndata['feat']`` is
    removed, and the dense adjacency matrix is built once per step instead
    of twice inside the loss expression.
    """
    # NOTE(review): hard-codes 10 input feature dims and a latent size of 5
    # to match the MiniGCDataset example — confirm against the caller.
    model = VGAE(10, 5)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(10):
        for graph, _ in dataset:
            optimizer.zero_grad()
            mean, logstd = model(graph)
            # Reparameterization trick: z = mu + sigma * eps.
            z = mean + torch.randn([graph.num_nodes(), 5]) * torch.exp(logstd)
            # Inner-product decoder reconstructs edge probabilities.
            recon = torch.sigmoid(torch.matmul(z, z.t()))
            # Densify the adjacency once; the original computed this twice.
            adj = graph.adjacency_matrix().to_dense()
            loss = -torch.mean(
                torch.sum(adj * torch.log(recon)
                          + (1 - adj) * torch.log(1 - recon), dim=1))
            kl_loss = -0.5 * torch.mean(torch.sum(
                1 + 2 * logstd - mean.pow(2) - torch.exp(logstd).pow(2), dim=1))
            loss += kl_loss
            loss.backward()
            optimizer.step()
if __name__ == '__main__':
    # MiniGCDataset yields (graph, label) pairs — exactly what train()
    # iterates over. The original merged every graph into one batched graph
    # with dgl.batch and passed that to train(), which would fail because a
    # single DGLGraph is not iterable as (graph, label) pairs.
    dataset = MiniGCDataset(80, 10, 20)
    train(dataset)
```
在这个示例中,我们首先定义了一个VGAE模型,然后在`train()`函数中对数据中的每个图进行训练。损失函数包括重构误差和KL散度:前者衡量重构邻接矩阵与原始邻接矩阵的差异,后者衡量隐变量分布与标准正态分布的差异。示例的`__main__`部分还演示了用`dgl.batch`将多个图合并成一个大图。
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)