提交 4550efa6 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5532 from nouiz/interface_cleanup

Interface addition and cleanup
...@@ -1339,7 +1339,7 @@ def _float_ones_like(x): ...@@ -1339,7 +1339,7 @@ def _float_ones_like(x):
if dtype not in tensor.float_dtypes: if dtype not in tensor.float_dtypes:
dtype = theano.config.floatX dtype = theano.config.floatX
return tensor.ones_like(x, dtype=dtype) return x.ones_like(dtype=dtype)
class numeric_grad(object): class numeric_grad(object):
......
...@@ -790,6 +790,12 @@ class _scalar_py_operators: ...@@ -790,6 +790,12 @@ class _scalar_py_operators:
dtype = str(self.type.dtype) dtype = str(self.type.dtype)
return second(self, ScalarConstant(get_scalar_type(dtype), 0)) return second(self, ScalarConstant(get_scalar_type(dtype), 0))
def ones_like(self, dtype=None):
# The second is needed for Elemwise ops to work right
if dtype is None:
dtype = str(self.type.dtype)
return second(self, ScalarConstant(get_scalar_type(dtype), 1))
def astype(self, dtype): def astype(self, dtype):
return cast(self, dtype) return cast(self, dtype)
......
...@@ -1822,8 +1822,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1822,8 +1822,7 @@ class ScanMerge(gof.Optimizer):
len(rep.inputs) != len(node.inputs) or len(rep.inputs) != len(node.inputs) or
len(rep.outputs) != len(node.outputs) or len(rep.outputs) != len(node.outputs) or
node.op.truncate_gradient != rep.op.truncate_gradient or node.op.truncate_gradient != rep.op.truncate_gradient or
node.op.mode != rep.op.mode or node.op.mode != rep.op.mode):
rep.op.as_while != node.op.as_while):
return False return False
nsteps = node.inputs[0] nsteps = node.inputs[0]
......
...@@ -731,6 +731,9 @@ class _tensor_py_operators(object): ...@@ -731,6 +731,9 @@ class _tensor_py_operators(object):
def zeros_like(model, dtype=None): def zeros_like(model, dtype=None):
return theano.tensor.basic.zeros_like(model, dtype=dtype) return theano.tensor.basic.zeros_like(model, dtype=dtype)
def ones_like(model, dtype=None):
return theano.tensor.basic.ones_like(model, dtype=dtype)
def cumsum(self, axis=None): def cumsum(self, axis=None):
return theano.tensor.extra_ops.cumsum(self, axis) return theano.tensor.extra_ops.cumsum(self, axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论