提交 ca40ef22 authored 作者: Faruk Ahmed's avatar Faruk Ahmed 提交者: Frederic Bastien

other places

上级 1245e3c6
...@@ -1792,19 +1792,18 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1792,19 +1792,18 @@ class OpWiseCLinker(link.LocalLinker):
# There are ops that don't have _op_use_c_code property # There are ops that don't have _op_use_c_code property
# for example ifelse (or any ops that come with their own # for example ifelse (or any ops that come with their own
# make_thunk # make_thunk
old_value = getattr(node.op, '_op_use_c_code', False) if theano.config.cxx:
try: thunks += [node.op.make_c_thunk(node,
if theano.config.cxx: storage_map,
node.op._op_use_c_code = True compute_map,
no_recycling)]
else:
thunks += [node.op.make_thunk(node, thunks += [node.op.make_thunk(node,
storage_map, storage_map,
compute_map, compute_map,
no_recycling)] no_recycling)]
thunks[-1].inputs = [storage_map[v] for v in node.inputs] thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs] thunks[-1].outputs = [storage_map[v] for v in node.outputs]
finally:
node.op._op_use_c_code = old_value
for node in order: for node in order:
if self.allow_gc: if self.allow_gc:
......
...@@ -1044,11 +1044,15 @@ class VM_Linker(link.LocalLinker): ...@@ -1044,11 +1044,15 @@ class VM_Linker(link.LocalLinker):
for node in order: for node in order:
try: try:
if self.c_thunks is False: if self.c_thunks is False:
node.op._op_use_c_code = False thunks.append(node.op.make_py_thunk(node,
thunks.append(node.op.make_thunk(node, storage_map,
storage_map, compute_map,
compute_map, no_recycling))
no_recycling)) else:
thunks.append(node.op.make_thunk(node,
storage_map,
compute_map,
no_recycling))
if not hasattr(thunks[-1], 'lazy'): if not hasattr(thunks[-1], 'lazy'):
# We don't want all ops maker to think about lazy Ops. # We don't want all ops maker to think about lazy Ops.
# So if they didn't specify that its lazy or not, it isn't. # So if they didn't specify that its lazy or not, it isn't.
......
...@@ -6297,15 +6297,10 @@ def constant_folding(node): ...@@ -6297,15 +6297,10 @@ def constant_folding(node):
compute_map[o] = [False] compute_map[o] = [False]
if (hasattr(node.op, 'python_constant_folding') and if (hasattr(node.op, 'python_constant_folding') and
node.op.python_constant_folding(node)): node.op.python_constant_folding(node)):
old_value = getattr(node.op, '_op_use_c_code', False) thunk = node.op.make_py_thunk(node,
try: storage_map,
node.op._op_use_c_code = False compute_map,
thunk = node.op.make_thunk(node, [])
storage_map,
compute_map,
[])
finally:
node.op._op_use_c_code = old_value
else: else:
thunk = node.op.make_thunk(node, storage_map, compute_map, thunk = node.op.make_thunk(node, storage_map, compute_map,
no_recycling=[]) no_recycling=[])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论