提交 f85d6b82 authored 作者: goodfeli's avatar goodfeli

Merge pull request #27 from nouiz/ones_zeros_like_dtype

Add the dtype param to tensor.{zeros,ones}_like as numpy. Test them.
...@@ -2338,17 +2338,19 @@ pprint.assign(fill, printing.FunctionPrinter('fill')) ...@@ -2338,17 +2338,19 @@ pprint.assign(fill, printing.FunctionPrinter('fill'))
@constructor @constructor
def ones_like(model): def ones_like(model, dtype=None):
"""WRITEME""" """equivalent of numpy.ones_like"""
#return Ones(model.type.ndim)(shape(model)) if dtype is None:
ret= fill(model, constant(1.0, dtype=model.type.dtype)) dtype = model.type.dtype
ret= fill(model, constant(1.0, dtype=dtype))
return ret return ret
@constructor @constructor
def zeros_like(model): def zeros_like(model, dtype=None):
"""WRITEME""" """equivalent of numpy.zeros_like"""
#return Zeros(model.type.ndim)(shape(model)) if dtype is None:
return fill(model, constant(0.0, dtype=model.type.dtype)) dtype = model.type.dtype
return fill(model, constant(0.0, dtype=dtype))
def zeros(shape, dtype=config.floatX): def zeros(shape, dtype=config.floatX):
......
...@@ -915,6 +915,15 @@ ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace, ...@@ -915,6 +915,15 @@ ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace,
inplace = True, inplace = True,
skip = skip_scipy) skip = skip_scipy)
ZerosLikeTester = makeBroadcastTester(op = zeros_like,
expected = numpy.zeros_like,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
OnesLikeTester = makeBroadcastTester(op = ones_like,
expected = numpy.ones_like,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
DotTester = makeTester(name = 'DotTester', DotTester = makeTester(name = 'DotTester',
op = dot, op = dot,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论