def affine_forward(x, w, b):
    out = None
    N = x.shape[0]
    x_row = x.reshape(N, -1)    # (N, D); -1 lets NumPy infer the column count once the row count is fixed: 2 * 32 * 16 * 16 / 2 = 8192
    out = np.dot(x_row, w) + b  # (N, M): (2, 8192) x (8192, 100) => (2, 100)
    cache = (x, w, b)
    return out, cache

Why is the input x (2, 32, 16, 16), x_row (2, 8192), the weight w (8192, 100), and the output of the matrix product (2, 8192) x (8192, 100) => (2, 100)?
Time: 2023-12-22 13:04:34 Views: 66
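In this code the input x has shape (2, 32, 16, 16). After the reshape, x_row has shape (2, 8192), because 32 * 16 * 16 = 8192. The weight w has shape (8192, 100), so the matrix product np.dot(x_row, w) has shape (2, 100). Adding the bias b, which broadcasts across the rows provided it has shape (100,) or (1, 100), leaves the output at shape (2, 100). The implementation is therefore correct.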
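The shape arithmetic above can be checked directly. The snippet below is a minimal sketch that uses random data with the shapes from the question; the array contents, the seed, and the bias shape (100,) are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)             # arbitrary seed, for reproducibility only
x = rng.standard_normal((2, 32, 16, 16))   # batch of 2 samples
w = rng.standard_normal((8192, 100))       # weights: 8192 inputs -> 100 outputs
b = np.zeros(100)                          # assumed bias shape (100,)

x_row = x.reshape(x.shape[0], -1)          # -1 is inferred as 32 * 16 * 16
out = np.dot(x_row, w) + b                 # b broadcasts across the 2 rows

print(x_row.shape)  # (2, 8192)
print(out.shape)    # (2, 100)
```

Passing -1 for the second dimension works for any input layout, since NumPy divides the total element count by the fixed row count N.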
Related questions
def affine_forward(x, w, b):
    out = None
    N = x.shape[0]
    x_row = x.reshape(N, -1)
    out = np.dot(x_row, w) + b.reshape(1, -1)
    cache = (x, w, b)
    return out, cache

This raises "shapes (40,6272) and (784,100) not aligned: 6272 (dim 1) != 784 (dim 0)". After changing it to

def affine_forward(x, w, b):
    out = None
    N = x.shape[0]
    x_row = x.reshape(N, -1)
    w_row = w.reshape(w.shape[0], -1).T
    out = np.dot(x_row, w_row) + b.reshape(1, -1)
    cache = (x, w, b)
    return out, cache

it raises "shapes (40,6272) and (100,8192) not aligned: 6272 (dim 1) != 100 (dim 0)".
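The error is raised by the np.dot(x_row, w_row) call: the second dimension of x_row is 6272 and the first dimension of w_row is 100, so the two matrices cannot be multiplied. Transposing w does not help. np.dot needs the weight matrix to have as many rows as each sample has flattened features, which is 6272 here, but judging from the two error messages the weights passed in had 784 rows in the first case and 8192 in the second. affine_forward should stay as originally written; the fix belongs in the calling code, which has to create w with shape (6272, M). The version below keeps the original computation and only adds a check that reports the mismatch clearly:
```
import numpy as np

def affine_forward(x, w, b):
    N = x.shape[0]
    x_row = x.reshape(N, -1)  # (N, D): flatten each sample into one row
    # w must have D rows; a transpose cannot repair a wrong feature count
    if x_row.shape[1] != w.shape[0]:
        raise ValueError(f"w needs {x_row.shape[1]} rows, got {w.shape[0]}")
    out = np.dot(x_row, w) + b.reshape(1, -1)  # (N, D) x (D, M) + (1, M) => (N, M)
    cache = (x, w, b)
    return out, cache
```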
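One way to avoid this class of error is to derive the weight shape from the data instead of hard-coding it. The sketch below assumes a hypothetical helper, init_affine_params, which is not part of the original code, and an input of shape (40, 8, 28, 28), chosen only because 8 * 28 * 28 = 6272; the real input shape is not given in the question.

```python
import numpy as np

def init_affine_params(x, num_outputs, scale=1e-2, seed=0):
    """Hypothetical helper: build w and b that match x's flattened feature count."""
    num_features = int(np.prod(x.shape[1:]))  # D = product of all non-batch dims
    rng = np.random.default_rng(seed)
    w = scale * rng.standard_normal((num_features, num_outputs))  # (D, M)
    b = np.zeros(num_outputs)                                     # (M,)
    return w, b

x = np.zeros((40, 8, 28, 28))   # assumed shape; flattens to (40, 6272)
w, b = init_affine_params(x, 100)
out = np.dot(x.reshape(x.shape[0], -1), w) + b.reshape(1, -1)
print(w.shape, out.shape)       # (6272, 100) (40, 100)
```

Because the row count of w is computed from x.shape[1:], the same call works unchanged if an earlier layer changes the feature map size.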
Fully connected layer: a matrix transformation in which the row and column counts have to line up. Input x: (2, 32, 16, 16); x_row: (2, 8192); weight w: (8192, 100); output: matrix product (2, 8192) x (8192, 100) => (2, 100)

def affine_forward(x, w, b):
    out = None
    N = x.shape[0]
    x_row = x.reshape(N, -1)
    out = np.dot(x_row, w) + b.reshape(1, -1)
    cache = (x, w, b)
    return out, cache

def affine_relu_forward(x, w, b):
    x = x.reshape(x.shape[0], -1)
    w = np.random.randn(784, 100)
    b = np.zeros((1, 100))
    out, cache = layers.affine_forward(x, w, b)
    a, fc_cache = layers.affine_forward(x, w, b)
    out, relu_cache = layers.relu_forward(a)
    cache = (fc_cache, relu_cache)
    return out, cache

ValueError: shapes (40,6272) and (784,100) not aligned: 6272 (dim 1) != 784 (dim 0)
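The shapes in the comment are consistent with each other. An input x of shape (2, 32, 16, 16) is reshaped into x_row with 2 rows and 32 * 16 * 16 = 8192 columns, one row per sample. The weight w has shape (8192, 100): multiplying a sample's 8192 features by w produces a vector of length 100, so w has 100 columns.
The output has shape (2, 100) because the length-100 vectors for the 2 samples are stacked into a matrix with 2 rows and 100 columns, and adding the bias b leaves that shape unchanged. The ValueError comes from a different set of shapes. Inside affine_relu_forward, w is hard-coded as np.random.randn(784, 100), overwriting the w that was passed in, while the batch actually supplied flattens to (40, 6272). The inner dimensions 6272 and 784 do not match, so np.dot fails. The weight must have as many rows as x has features per sample, which is 6272 here.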
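A minimal sketch of affine_relu_forward that avoids the mismatch in the traceback takes w and b as arguments instead of creating them inside the function. The relu_forward below is an assumed implementation, since the original layers module is not shown, and the (40, 8, 28, 28) input is only an example that flattens to 6272 features.

```python
import numpy as np

def affine_forward(x, w, b):
    x_row = x.reshape(x.shape[0], -1)          # (N, D)
    out = np.dot(x_row, w) + b.reshape(1, -1)  # (N, M)
    return out, (x, w, b)

def relu_forward(x):
    # Assumed implementation: elementwise max(0, x), with x cached for backward
    return np.maximum(0, x), x

def affine_relu_forward(x, w, b):
    a, fc_cache = affine_forward(x, w, b)      # affine step
    out, relu_cache = relu_forward(a)          # nonlinearity
    return out, (fc_cache, relu_cache)

x = np.ones((40, 8, 28, 28))                   # example shape, 8 * 28 * 28 = 6272
w = np.full((6272, 100), -1.0)                 # rows match the feature count
b = np.zeros(100)
out, cache = affine_relu_forward(x, w, b)
print(out.shape)                               # (40, 100)
```

Keeping parameter creation outside the forward function also means the same weights are reused on every call, which training requires.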