提交 a471fa65 authored 作者: Frederic's avatar Frederic

Add the dtype param to tensor.{zeros,ones}_like as numpy. Test them.

上级 bca3e403
......@@ -2338,17 +2338,19 @@ pprint.assign(fill, printing.FunctionPrinter('fill'))
@constructor
def ones_like(model):
"""WRITEME"""
#return Ones(model.type.ndim)(shape(model))
ret= fill(model, constant(1.0, dtype=model.type.dtype))
def ones_like(model, dtype=None):
"""equivalent of numpy.ones_like"""
if dtype is None:
dtype = model.type.dtype
ret= fill(model, constant(1.0, dtype=dtype))
return ret
@constructor
def zeros_like(model):
"""WRITEME"""
#return Zeros(model.type.ndim)(shape(model))
return fill(model, constant(0.0, dtype=model.type.dtype))
def zeros_like(model, dtype=None):
"""equivalent of numpy.zeros_like"""
if dtype is None:
dtype = model.type.dtype
return fill(model, constant(0.0, dtype=dtype))
def zeros(shape, dtype=config.floatX):
......
......@@ -915,6 +915,15 @@ ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace,
inplace = True,
skip = skip_scipy)
ZerosLikeTester = makeBroadcastTester(op = zeros_like,
expected = numpy.zeros_like,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
OnesLikeTester = makeBroadcastTester(op = ones_like,
expected = numpy.ones_like,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
DotTester = makeTester(name = 'DotTester',
op = dot,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论