提交 de39392d authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test: now zeros_like and ones_like return the same type as the model

上级 3d01f413
......@@ -33,7 +33,8 @@ class ConvOp(Op):
unroll_kern=4,
imshp_logical=None,
kshp_logical=None,
kshp_logical_top_aligned=True):
kshp_logical_top_aligned=True,
version=-1):
"""
......@@ -46,7 +47,7 @@ class ConvOp(Op):
out_mode - 'valid', 'full'
unroll_batch - c code generation option
unroll_kern - c code generation option
version - passed to GpuConv.
The reason that this op does the summation over convolutions within the 'stack' is that
it allows us to be memory-efficient about how gradients are calculated. If, for
......@@ -77,6 +78,7 @@ class ConvOp(Op):
self.bsize=bsize
self.dx=dx
self.dy=dy
self.version=version
# a triple
self.imshp_logical = self.imshp if imshp_logical is None else tuple(imshp_logical)
assert len(self.imshp) == len(self.imshp_logical)
......
......@@ -133,7 +133,7 @@ _as_tensor_variable = as_tensor_variable
as_tensor = as_tensor_variable
def constant_or_value(x, rtype, name=None, ndim=None):
def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
"""Return a symbolic `Constant` with value `x`
:Exceptions:
......@@ -141,6 +141,9 @@ def constant_or_value(x, rtype, name=None, ndim=None):
- `ValueError`: `x` could not be expanded to have ndim dimensions
"""
if dtype is not None:
x_ = numpy.asarray(x, dtype=dtype)
else:
x_ = None
if rtype is TensorConstant and isinstance(x, int):
for dtype in ['int8', 'int16', 'int32', 'int64']:
......@@ -175,11 +178,11 @@ def constant_or_value(x, rtype, name=None, ndim=None):
except:
raise TypeError("Could not convert %s to TensorType" % x, type(x))
def constant(x, name=None, ndim=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim)
def constant(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim, dtype=dtype)
def value(x, name=None, ndim=None):
return constant_or_value(x, rtype=TensorValue, name=name, ndim=ndim)
def value(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorValue, name=name, ndim=ndim, dtype=dtype)
def _obj_is_wrappable_as_tensor(x):
try:
......@@ -1234,13 +1237,14 @@ pprint.assign(fill, printing.FunctionPrinter('fill'))
def ones_like(model):
"""WRITEME"""
#return Ones(model.type.ndim)(shape(model))
return fill(model, 1.0)
ret= fill(model, constant(1.0, dtype=model.type.dtype))
return ret
@constructor
def zeros_like(model):
"""WRITEME"""
#return Zeros(model.type.ndim)(shape(model))
return fill(model, 0.0)
return fill(model, constant(0.0, dtype=model.type.dtype))
class Filler(gof.Op):
"""WRITEME"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论