提交 de2689f2 authored 作者: Roy Xue's avatar Roy Xue

update:

Add TensorConstant check in check_node_state method.
上级 441f01cf
...@@ -24,6 +24,7 @@ from collections import defaultdict ...@@ -24,6 +24,7 @@ from collections import defaultdict
import numpy import numpy
import theano import theano
from theano.gof import Constant
from theano.configparser import AddConfigVar, BoolParam, IntParam from theano.configparser import AddConfigVar, BoolParam, IntParam
...@@ -702,9 +703,12 @@ class ProfileStats(object): ...@@ -702,9 +703,12 @@ class ProfileStats(object):
compute_map = defaultdict(lambda: [0]) compute_map = defaultdict(lambda: [0])
# compute_map use to check if a node is valid # compute_map use to check if a node is valid
for node in node_list: # for node in node_list:
for val in node.inputs: # for val in node.inputs:
compute_map[val][0] = 1 # compute_map[val][0] = 1
for node in fgraph.inputs:
compute_map[node][0] = 1
print fgraph.outputs
def check_node_state(node): def check_node_state(node):
""" """
...@@ -715,6 +719,9 @@ class ProfileStats(object): ...@@ -715,6 +719,9 @@ class ProfileStats(object):
inputs = node.inputs inputs = node.inputs
outputs = node.outputs outputs = node.outputs
deps = inputs + node.destroy_dependencies deps = inputs + node.destroy_dependencies
for node in deps:
if isinstance(node, Constant):
compute_map[node][0] = 1
computed_ins = all(compute_map[v][0] for v in inputs) computed_ins = all(compute_map[v][0] for v in inputs)
computed_outs = all(compute_map[v][0] for v in outputs) computed_outs = all(compute_map[v][0] for v in outputs)
# check if there could be a compute_map # check if there could be a compute_map
...@@ -736,6 +743,9 @@ class ProfileStats(object): ...@@ -736,6 +743,9 @@ class ProfileStats(object):
for i in range(len(node_list)): for i in range(len(node_list)):
v = node_list[i:i+1] v = node_list[i:i+1]
if check_node_state(v[0]): if check_node_state(v[0]):
# print v[0].inputs
# print v[0].outputs
# print compute_map
for node in v[0].outputs: for node in v[0].outputs:
compute_map[node][0] = 1 compute_map[node][0] = 1
if len(node_list) == 1: if len(node_list) == 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论