tmp_d = {} r2_get = [] for i, feat in enumerate(data_history_conversion): tmp_d[i] = np.array([[item] for item in feat[i]]) r2 = r2_score(tmp_d[i], data_pre_conversion) r2_get.append(r2) sort_r2 = heapq.nlargest(5, r2_get) r2_index = [r2_get.index(x) for x in sort_r2] 代码优化
时间: 2024-01-20 18:03:59 浏览: 96
可以优化的地方如下:
1. 在循环中,每次都创建一个新的字典项并添加到字典中。这种方式比较消耗资源,可以在循环外部直接创建一个空字典,然后在循环中更新它。
2. 每次使用 r2_score 函数计算 r2 值时,都需要将数据转换成 numpy 数组。可以在循环外部将 data_pre_conversion 转换成 numpy 数组,这样就可以避免在循环中进行多次转换了。
3. 为了找到前五个最大的 r2 值,使用了 heapq.nlargest 函数和列表索引。可以直接使用 numpy.argsort 函数来获取按照从大到小排序后的索引值,然后取前五个即可。
修改后的代码如下:
```python
tmp_d = {}
data_pre_conversion = np.array(data_pre_conversion)
for i, feat in enumerate(data_history_conversion):
tmp_d[i] = np.array([[item] for item in feat[i]])
r2 = r2_score(tmp_d[i], data_pre_conversion)
r2_get.append(r2)
sort_r2_index = np.argsort(r2_get)[::-1][:5]
r2_index = [i for i in sort_r2_index]
```
这样可以避免重复转换数据和创建多余的字典项,同时也使用了 numpy.argsort 函数来简化代码。
相关问题
data.fillna(method='ffill', inplace=True) date_history,data_history = pd.DataFrame(data.iloc[:, 0]) data_history = pd.DataFrame(data.iloc[:, 1]) date_history = np.array(date_history) data_history = [x for item in np.array(data_history).tolist() for x in item] # 缺失值处理 history_time_list = [] for date in date_history: date_obj = datetime.datetime.strptime(date[0], '%Y/%m/%d %H:%M') #将字符串转为 datetime 对象 history_time_list.append(date_obj) start_time = history_time_list[0] # 起始时间 end_time = history_time_list[-1] # 结束时间 delta = datetime.timedelta(minutes=15) #时间间隔为15分钟 time_new_list = [] current_time = start_time while current_time <= end_time: time_new_list.append(current_time) current_time += delta # 缺失位置记录 code_list = [] for i in range(len(time_new_list)): code_list = code_list history_time_list = history_time_list while (time_new_list[i] - history_time_list[i]) != datetime.timedelta(minutes=0): history_time_list.insert(i, time_new_list[i]) code_list.append(i) for i in code_list: data_history.insert(i, data_history[i - 1]) # 输出补充好之后的数据 data = pd.DataFrame({'date': time_new_list, 'load': data_history}) return data 代码优化
以下是对代码的优化:
```python
data.fillna(method='ffill', inplace=True)
date_history, data_history = data.iloc[:, 0], data.iloc[:, 1:].values.flatten()
date_history = np.array([datetime.datetime.strptime(date, '%Y/%m/%d %H:%M') for date in date_history])
start_time, end_time = date_history[0], date_history[-1]
delta = datetime.timedelta(minutes=15)
time_new_list = []
current_time = start_time
while current_time <= end_time:
time_new_list.append(current_time)
current_time += delta
code_list = [i for i, date in enumerate(date_history) if date not in time_new_list]
for i in code_list:
data_history = np.insert(data_history, i, data_history[i - 1])
data = pd.DataFrame({'date': time_new_list, 'load': data_history})
return data
```
代码优化的主要思路是:
1. 将第二列数据展平成一维数组,避免后续操作需要用到的循环。
2. 将日期字符串转换为 datetime 对象。
3. 使用列表推导式生成时间序列。
4. 使用列表推导式生成缺失位置列表。
5. 使用 NumPy 的 `insert()` 方法在数据中插入缺失值。
6. 最后将补充好的数据转换为 DataFrame 返回。
这样可以使代码更加简洁、高效,并且减少不必要的循环。
def freq_domain_loss(s_hat, gt_spec, combination=True):n_src = len(s_hat) idx_list = [i for i in range(n_src)] inferences = [] refrences = [] for i, s in enumerate(s_hat): inferences.append(s) refrences.append(gt_spec[..., 2 * i : 2 * i + 2, :]) assert inferences[0].shape == refrences[0].shape _loss_mse = 0.0 cnt = 0.0 for i in range(n_src): _loss_mse += singlesrc_mse(inferences[i], refrences[i]).mean() cnt += 1.0 # If Combination is True, calculate the expected combinations. if combination: for c in range(2, n_src): patterns = list(itertools.combinations(idx_list, c)) for indices in patterns: tmp_loss = singlesrc_mse( sum(itemgetter(*indices)(inferences)), sum(itemgetter(*indices)(refrences)), ).mean() _loss_mse += tmp_loss cnt += 1.0 _loss_mse /= cnt return _loss_mse
这个函数实现了一个频域损失函数,用于衡量音频信号重构的准确度。它的输入参数包括一个音频信号的估计值 s_hat 和对应的参考值 gt_spec,以及一个布尔值 combination,用于指定是否考虑多个信号源的情况。如果 combination 为 True,则会计算所有可能的信号源组合情况下的损失值,否则只计算单个信号源的损失值。该函数首先将输入的 s_hat 和 gt_spec 分别按照源数量进行拆分,然后分别计算每个源信号与其参考信号之间的均方误差(MSE)并累加起来。如果 combination 为 True,则还会计算所有可能的源信号组合情况下的 MSE 并进行累加。最后,函数将累加后的 MSE 值除以源数量得到平均值,并将其作为输出返回。
阅读全文
相关推荐
![-](https://img-home.csdnimg.cn/images/20241231044736.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)