提交 d5bdbb4a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use x.zeros_like().

上级 186360ce
...@@ -13,7 +13,7 @@ except ImportError: ...@@ -13,7 +13,7 @@ except ImportError:
pass pass
from type import GpuArrayType from type import GpuArrayType
from basic_ops import as_gpuarray_variable, zeros_like from basic_ops import as_gpuarray_variable
class GpuSubtensor(Subtensor): class GpuSubtensor(Subtensor):
def make_node(self, x, *inputs): def make_node(self, x, *inputs):
...@@ -102,7 +102,7 @@ class GpuSubtensor(Subtensor): ...@@ -102,7 +102,7 @@ class GpuSubtensor(Subtensor):
rest = inputs[1:] rest = inputs[1:]
output = self(*inputs) output = self(*inputs)
if output.dtype.find('int') != -1: if output.dtype.find('int') != -1:
first = zeros_like(x, theano.config.floatX) first = x.zeros_like(theano.config.floatX)
else: else:
first = GpuIncSubtensor(self.idx_list)(zeros_like(x), gz, *rest) first = GpuIncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest)
return ([first] + [DisconnectedType()()] * len(rest)) return ([first] + [DisconnectedType()()] * len(rest))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论