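3. Get the data and initialize the network (code provided). Call the get_data and init_network functions, then print the shapes of x_train, t_train, x_test, and t_test, along with the shape of every parameter in each layer of network (three layers in total).

    # Load the MNIST data
    def get_data():
        (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
        return x_train, t_train, x_test, t_test

    # Initialize the network; network is a dict holding each layer's parameters W and b
    def init_network():
        with open("sample_weight.pkl", 'rb') as f:
            network = pickle.load(f)
        return network

4. Define a predict function that recognizes handwritten digits.

Recognition method: let the input handwritten-digit image be x, a vector of dimension 784 (a 28*28 image flattened into one dimension). The first layer has weights W1 (shape (784, 50)) and bias b1 (dimension 50), and its output is z1 = sigmoid(x*W1 + b1). The second layer has weights W2 (shape (50, 100)) and bias b2 (dimension 100), and its output is z2 = sigmoid(z1*W2 + b2). The third layer has weights W3 (shape (100, 10)) and bias b3 (dimension 10), and its output, which is the recognition result, is p = softmax(z2*W3 + b3). Here p is a vector of dimension 10 (the number of classes) that gives the probability of x belonging to each class. For example, p = [0, 0, 0.95, 0.05, 0, 0, 0, 0, 0, 0] means x belongs to the third class (digit 2) with probability 0.95, to the fourth class (digit 3) with probability 0.05, and to every other class with probability 0. Because the third class has the highest probability, x is assigned to the third class.

5. Compute the classification accuracy of the handwritten digit recognition (overall accuracy) and print it. For example, if there are 100 test samples and 92 of them are classified correctly, then accuracy = 92/100 = 0.92.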
Time: 2023-11-22 18:54:22
The code below implements these steps:
```python
import pickle

import numpy as np
# load_mnist is provided by the dataset/mnist.py module
from dataset.mnist import load_mnist


# Load the MNIST data
def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(
        normalize=True, flatten=True, one_hot_label=False)
    return x_train, t_train, x_test, t_test


# Initialize the network; network is a dict holding each layer's parameters W and b
def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


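# Sigmoid activation, applied element-wise
def sigmoid(x):
    return 1 / (1 + np.exp(-x))


# Softmax over the last axis, so each row of a batch is normalized on its own
def softmax(x):
    c = np.max(x, axis=-1, keepdims=True)  # subtract the row maximum to avoid overflow
    exp_x = np.exp(x - c)
    sum_exp_x = np.sum(exp_x, axis=-1, keepdims=True)
    y = exp_x / sum_exp_x
    return y

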
# Recognize handwritten digits: forward pass through the three layers
def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']
    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)
    return y


# Get the data and initialize the network
x_train, t_train, x_test, t_test = get_data()
network = init_network()
# Print the shapes of x_train, t_train, x_test, t_test and of each layer's parameters in network
print("x_train.shape:", x_train.shape)
print("t_train.shape:", t_train.shape)
print("x_test.shape:", x_test.shape)
print("t_test.shape:", t_test.shape)
print("W1.shape:", network['W1'].shape)
print("b1.shape:", network['b1'].shape)
print("W2.shape:", network['W2'].shape)
print("b2.shape:", network['b2'].shape)
print("W3.shape:", network['W3'].shape)
print("b3.shape:", network['b3'].shape)
# Recognize the test digits in batches and compute the classification accuracy
batch_size = 100
accuracy_cnt = 0
for i in range(0, len(x_test), batch_size):
    x_batch = x_test[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)  # index of the most probable class per image
    accuracy_cnt += np.sum(p == t_test[i:i+batch_size])
print("Accuracy:", float(accuracy_cnt) / len(x_test))
```
The output is as follows:
```
x_train.shape: (60000, 784)
t_train.shape: (60000,)
x_test.shape: (10000, 784)
t_test.shape: (10000,)
W1.shape: (784, 50)
b1.shape: (50,)
W2.shape: (50, 100)
b2.shape: (100,)
W3.shape: (100, 10)
b3.shape: (10,)
Accuracy: 0.9352
```
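
To sanity-check the forward pass without the MNIST files or sample_weight.pkl, the following minimal sketch runs the same three-layer computation on random weights of the shapes listed above, then works through the accuracy formula on a tiny hand-made example. The names toy_network, x_toy, predicted, and labels are illustrative assumptions, not part of the original assignment, and the random weights produce meaningless predictions; only the shapes and the normalization are being checked.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x):
    # Row-wise, numerically stable softmax
    x = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)


def predict(network, x):
    z1 = sigmoid(np.dot(x, network['W1']) + network['b1'])
    z2 = sigmoid(np.dot(z1, network['W2']) + network['b2'])
    return softmax(np.dot(z2, network['W3']) + network['b3'])


# Hypothetical stand-in for sample_weight.pkl: random values, same shapes
rng = np.random.default_rng(0)
shapes = {'W1': (784, 50), 'b1': (50,),
          'W2': (50, 100), 'b2': (100,),
          'W3': (100, 10), 'b3': (10,)}
toy_network = {name: rng.normal(scale=0.1, size=shape)
               for name, shape in shapes.items()}

# Five fake "images" with pixel values in [0, 1), like normalize=True
x_toy = rng.random((5, 784))
probs = predict(toy_network, x_toy)
print(probs.shape)                          # (5, 10)
print(np.allclose(probs.sum(axis=1), 1.0))  # True: each row is a distribution

# Accuracy = number of correct predictions / number of samples
predicted = np.array([2, 3, 7, 1, 0])
labels = np.array([2, 3, 7, 1, 4])
accuracy = np.mean(predicted == labels)
print(accuracy)                             # 0.8 (4 of 5 correct)
```

Each row of probs sums to 1 because the softmax normalizes along the last axis, which is what lets a whole batch be passed through predict in one call.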