提交 de39392d authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test: now zeros_like and ones_like return the same type as the model

上级 3d01f413
...@@ -33,7 +33,8 @@ class ConvOp(Op): ...@@ -33,7 +33,8 @@ class ConvOp(Op):
unroll_kern=4, unroll_kern=4,
imshp_logical=None, imshp_logical=None,
kshp_logical=None, kshp_logical=None,
kshp_logical_top_aligned=True): kshp_logical_top_aligned=True,
version=-1):
""" """
...@@ -46,7 +47,7 @@ class ConvOp(Op): ...@@ -46,7 +47,7 @@ class ConvOp(Op):
out_mode - 'valid', 'full' out_mode - 'valid', 'full'
unroll_batch - c code generation option unroll_batch - c code generation option
unroll_kern - c code generation option unroll_kern - c code generation option
version - passed to GpuConv.
The reason that this op does the summation over convolutions within the 'stack' is that The reason that this op does the summation over convolutions within the 'stack' is that
it allows us to be memory-efficient about how gradients are calculated. If, for it allows us to be memory-efficient about how gradients are calculated. If, for
...@@ -77,6 +78,7 @@ class ConvOp(Op): ...@@ -77,6 +78,7 @@ class ConvOp(Op):
self.bsize=bsize self.bsize=bsize
self.dx=dx self.dx=dx
self.dy=dy self.dy=dy
self.version=version
# a triple # a triple
self.imshp_logical = self.imshp if imshp_logical is None else tuple(imshp_logical) self.imshp_logical = self.imshp if imshp_logical is None else tuple(imshp_logical)
assert len(self.imshp) == len(self.imshp_logical) assert len(self.imshp) == len(self.imshp_logical)
......
...@@ -133,7 +133,7 @@ _as_tensor_variable = as_tensor_variable ...@@ -133,7 +133,7 @@ _as_tensor_variable = as_tensor_variable
as_tensor = as_tensor_variable as_tensor = as_tensor_variable
def constant_or_value(x, rtype, name=None, ndim=None): def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
"""Return a symbolic `Constant` with value `x` """Return a symbolic `Constant` with value `x`
:Exceptions: :Exceptions:
...@@ -141,6 +141,9 @@ def constant_or_value(x, rtype, name=None, ndim=None): ...@@ -141,6 +141,9 @@ def constant_or_value(x, rtype, name=None, ndim=None):
- `ValueError`: `x` could not be expanded to have ndim dimensions - `ValueError`: `x` could not be expanded to have ndim dimensions
""" """
if dtype is not None:
x_ = numpy.asarray(x, dtype=dtype)
else:
x_ = None x_ = None
if rtype is TensorConstant and isinstance(x, int): if rtype is TensorConstant and isinstance(x, int):
for dtype in ['int8', 'int16', 'int32', 'int64']: for dtype in ['int8', 'int16', 'int32', 'int64']:
...@@ -175,11 +178,11 @@ def constant_or_value(x, rtype, name=None, ndim=None): ...@@ -175,11 +178,11 @@ def constant_or_value(x, rtype, name=None, ndim=None):
except: except:
raise TypeError("Could not convert %s to TensorType" % x, type(x)) raise TypeError("Could not convert %s to TensorType" % x, type(x))
def constant(x, name=None, ndim=None): def constant(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim) return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim, dtype=dtype)
def value(x, name=None, ndim=None): def value(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorValue, name=name, ndim=ndim) return constant_or_value(x, rtype=TensorValue, name=name, ndim=ndim, dtype=dtype)
def _obj_is_wrappable_as_tensor(x): def _obj_is_wrappable_as_tensor(x):
try: try:
...@@ -1234,13 +1237,14 @@ pprint.assign(fill, printing.FunctionPrinter('fill')) ...@@ -1234,13 +1237,14 @@ pprint.assign(fill, printing.FunctionPrinter('fill'))
def ones_like(model): def ones_like(model):
"""WRITEME""" """WRITEME"""
#return Ones(model.type.ndim)(shape(model)) #return Ones(model.type.ndim)(shape(model))
return fill(model, 1.0) ret= fill(model, constant(1.0, dtype=model.type.dtype))
return ret
@constructor @constructor
def zeros_like(model): def zeros_like(model):
"""WRITEME""" """WRITEME"""
#return Zeros(model.type.ndim)(shape(model)) #return Zeros(model.type.ndim)(shape(model))
return fill(model, 0.0) return fill(model, constant(0.0, dtype=model.type.dtype))
class Filler(gof.Op): class Filler(gof.Op):
"""WRITEME""" """WRITEME"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论