提交 ca0a627d authored 作者: Frederic Bastien's avatar Frederic Bastien

Updates following code review.

上级 792d5d2d
......@@ -125,13 +125,13 @@ class OpFromGraph(gof.Op):
[type() for type in self.output_types])
def prepare_node(self, node, storage_map, compute_map, impl):
    """Lazily compile the inner graph for the pure-Python implementation.

    The compiled function is cached on ``self`` (the Op) rather than on
    ``node.tag``, so it is built at most once and shared by every apply
    node of this Op.  Only the ``'py'`` implementation needs it; other
    impls ignore this hook.

    Parameters mirror the generic ``Op.prepare_node`` interface;
    ``storage_map`` and ``compute_map`` are unused here.
    """
    # NOTE(review): the scraped diff also contained the pre-review
    # variant caching on node.tag.fn; only the accepted self.fn cache
    # is kept here.
    if not hasattr(self, "fn") and impl == 'py':
        self.fn = orig_function(self.new_inputs,
                                self.new_outputs,
                                **self.kwargs)
def perform(self, node, inputs, outputs):
variables = node.tag.fn(*inputs)
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
# TODO: when function's output-borrowing semantics are correct,
......
......@@ -55,6 +55,14 @@ class GpuCumsum(CumsumOp, GpuOp):
if node_.op.max_threads_dim0 is None or node_.op.max_grid_size1 is None or node_.op.max_grid_size2 is None:
cuda = theano.sandbox.cuda
device_id = cuda.use.device_number
if device_id is None:
cuda.use("gpu",
force=False,
default_to_move_computation_to_gpu=False,
move_shared_float32_to_gpu=False,
enable_cuda=False,
test_driver=True)
device_id = cuda.use.device_number
cuda_ndarray = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray
prop = cuda_ndarray.device_properties(device_id)
node_.op.max_threads_dim0 = prop['maxThreadsDim0']
......
......@@ -3663,14 +3663,15 @@ class Composite(ScalarOp):
# Postpone the creation in case it isn't needed.
# self.init_name() # self.name
self.name = None
self.prepare_node_called = set()
def prepare_node(self, node, storage_map, compute_map, impl):
    """Prepare this Composite and every op of its inner graph for ``impl``.

    For the ``'py'`` implementation the per-node Python impls are
    (re)built via ``init_py_impls``.  The recursive walk over the inner
    graph is memoized per implementation name in the
    ``self.prepare_node_called`` set, so each impl triggers the walk at
    most once per Composite instance.

    ``storage_map`` and ``compute_map`` are part of the generic
    ``Op.prepare_node`` signature and are not used here (``None`` is
    forwarded to the inner ops).
    """
    # NOTE(review): the scraped diff also contained the pre-review
    # node.tag.graph_prepare_node_called flag; only the accepted
    # set-based memoization is kept here.
    if impl == 'py':
        self.init_py_impls()  # populates self._impls
    if impl not in self.prepare_node_called:
        for n in theano.gof.graph.list_of_nodes(self.inputs, self.outputs):
            n.op.prepare_node(n, None, None, impl)
        self.prepare_node_called.add(impl)
def output_types(self, input_types):
if tuple(input_types) != self.inputs_type:
......
......@@ -895,6 +895,9 @@ second dimension
# It happen that make_thunk isn't called, like in
# get_scalar_constant_value
self.prepare_node(node, None, None, 'py')
# prepare_node will add ufunc to self or the tag
# depending if we can reuse it or not. So we need to
# test both again.
if self.ufunc:
ufunc = self.ufunc
else:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论