提交 f1e51970 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add dtype arg to zeros_like method of tensor vars

Conflicts: theano/tensor/elemwise.py
上级 2c26f2f5
...@@ -1925,8 +1925,8 @@ class _tensor_py_operators: ...@@ -1925,8 +1925,8 @@ class _tensor_py_operators:
def get_scalar_constant_value(self): def get_scalar_constant_value(self):
return get_scalar_constant_value(self) return get_scalar_constant_value(self)
def zeros_like(model): def zeros_like(model, dtype=None):
return zeros_like(model) return zeros_like(model, dtype=dtype)
class TensorVariable(_tensor_py_operators, Variable): class TensorVariable(_tensor_py_operators, Variable):
......
...@@ -397,8 +397,7 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s); ...@@ -397,8 +397,7 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
# canonicalization optimization phase will remove the inplace. # canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph. # The inplace will be reintroduced automatically later in the graph.
if 'int' in inp[0].dtype: if 'int' in inp[0].dtype:
return [theano.tensor.zeros_like(inp[0], return [inp[0].zeros_like(dtype=theano.config.floatX)]
dtype=theano.config.floatX)]
else: else:
return [DimShuffle(gz.type.broadcastable, grad_order)( return [DimShuffle(gz.type.broadcastable, grad_order)(
Elemwise(scalar.identity)(gz))] Elemwise(scalar.identity)(gz))]
...@@ -622,7 +621,7 @@ class Elemwise(Op): ...@@ -622,7 +621,7 @@ class Elemwise(Op):
for idx, out in enumerate(outs): for idx, out in enumerate(outs):
# make such that _bgrads computes only the gradients of the # make such that _bgrads computes only the gradients of the
# current output on the inputs ( and not all outputs) # current output on the inputs ( and not all outputs)
ograds = [theano.tensor.zeros_like(x) for x in outs] ograds = [x.zeros_like() for x in outs]
ograds[idx] = theano.tensor.ones_like(out) ograds[idx] = theano.tensor.ones_like(out)
bgrads = self._bgrad(inputs, ograds) bgrads = self._bgrad(inputs, ograds)
...@@ -1780,7 +1779,7 @@ class Sum(CAReduceDtype): ...@@ -1780,7 +1779,7 @@ class Sum(CAReduceDtype):
out = self(*inp) out = self(*inp)
if out.dtype.find('int') != -1: if out.dtype.find('int') != -1:
return [theano.tensor.zeros_like(x, dtype=theano.config.floatX)] return [x.zeros_like(dtype=theano.config.floatX)]
gz, = grads gz, = grads
gz = as_tensor_variable(gz) gz = as_tensor_variable(gz)
...@@ -1897,8 +1896,7 @@ class Prod(CAReduceDtype): ...@@ -1897,8 +1896,7 @@ class Prod(CAReduceDtype):
if (out.dtype in discrete_dtypes or if (out.dtype in discrete_dtypes or
self.acc_dtype in discrete_dtypes): self.acc_dtype in discrete_dtypes):
# There is an int conversion in the way # There is an int conversion in the way
return [theano.tensor.zeros_like(prod_in, return [prod_in.zeros_like(dtype=theano.config.floatX)]
dtype=theano.config.floatX)]
# Prepare the broadcasting that is used everywhere to broadcast # Prepare the broadcasting that is used everywhere to broadcast
# over the original groups (ie. broadcast over the elements of a given # over the original groups (ie. broadcast over the elements of a given
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论