nodes_to_keep = [transpose_257] + list(transpose_257.next_nodes()) AttributeError: next_nodes
时间: 2024-09-25 20:05:47 浏览: 32
出现 `AttributeError: 'NodeProto' object has no attribute 'next_nodes'` 这个错误是因为`NodeProto` 类型的实例(在这个例子中是`transpose_257`)并没有`next_nodes` 属性。在ONNX中,节点之间的连接通常是通过边(EdgeProto)来表示的,而不是直接关联到每个节点。
为了获取节点之间的依赖关系,你需要遍历整个图,找到`Transpose_257` 的输入端口(InputPort),然后找出指向这些输入端口的输出端口(OutputPort)所对应的节点。这里是一个修改过的版本,以便正确地提取节点链:
```python
from onnx import GraphProto, NodeProto, ValueInfoProto
def get_nodes_after(node, graph):
visited = set() # 用来存储已经访问过的节点
stack = [(node, [])] # 初始化栈,第一个元素是节点和空列表
while stack:
current_node, path = stack.pop()
if current_node in visited:
continue
visited.add(current_node)
path.append(current_node) # 添加当前节点到路径
if isinstance(current_node, NodeProto): # 当前节点是运算符节点
# 查找该节点的输入并添加其对应的目标节点
for input_idx in range(len(current_node.input)):
target_node = graph.get_node_by_output_name(current_node.input[input_idx])
if target_node is not None:
stack.append((target_node, path))
elif isinstance(current_node, ValueInfoProto): # 当前节点是值信息,可能是前一个节点的输出
# 检查前一个节点是否为运算符节点,如果是,则追加其输出
last_operator = path[-1]
if isinstance(last_operator, NodeProto):
for output_port in last_operator.output:
target_node = graph.get_node_by_output_name(output_port)
if target_node is not None:
stack.append((target_node, path))
return path
nodes_to_keep = get_nodes_after(transpose_257, original_model.graph)
```
这个修改后的代码会返回从`Transpose_257` 开始的一系列节点。然后你可以按需选择哪些节点加入到`nodes_to_keep` 中。
阅读全文