import os from random import shuffle import pickle import librosa import tensorflow.keras as keras import librosa.display import numpy as np from matplotlib import pyplot as plt def load_files(audio_dir): files=os.listdir(audio_dir) wav_files=[] for wav in files: if not wav.endswith('.wav'): continue wav_files.append(wav) if not wav_files: print('未找到数据集') shuffle(wav_files) nfiles=len(wav_files) ntrain=int(nfiles*0.7) return wav_files[:ntrain],wav_files[ntrain:] audio_dir='../data/recordings/' dataset_pickle='../tmp/recordings.pkl' train_files,valid_files=load_files(audio_dir) print('训练集样本数为{}\n验证集样本数为{}'.format(len(train_files),load_files(valid_files)))帮我检查错误
时间: 2024-02-17 14:59:45 浏览: 100
1_notmnist.rar_notMNIST.pickle_tensorflow_tensorflow数据_基于python的
这段代码中有两个问题:
1. 在函数 `load_files()` 中,`wav_files` 列表只会包含以 `.wav` 结尾的文件名,因为 `if not wav.endswith('.wav')` 和 `wav_files.append(wav)` 没有缩进到同一个代码块中。你需要将两行代码缩进到同一级别,才能正确的将所有 `.wav` 文件名添加到 `wav_files` 列表中。
2. 在 `print()` 函数中使用了 `load_files()` 函数调用的返回值,而 `load_files()` 函数返回的是两个列表类型的对象,因此会报错。你需要将 `len(load_files(valid_files))` 改为 `len(valid_files)` 才能正确输出验证集样本数。
下面是已经修改过后的代码:
```python
import os
from random import shuffle
import pickle
import librosa
import tensorflow.keras as keras
import librosa.display
import numpy as np
from matplotlib import pyplot as plt
def load_files(audio_dir):
files = os.listdir(audio_dir)
wav_files = []
for wav in files:
if not wav.endswith('.wav'):
continue
wav_files.append(wav)
if not wav_files:
print('未找到数据集')
shuffle(wav_files)
n_files = len(wav_files)
n_train = int(n_files * 0.7)
return wav_files[:n_train], wav_files[n_train:]
audio_dir = '../data/recordings/'
dataset_pickle = '../tmp/recordings.pkl'
train_files, valid_files = load_files(audio_dir)
print('训练集样本数为{}\n验证集样本数为{}'.format(len(train_files), len(valid_files)))
```
希望能帮到你!
阅读全文