提交 facb53ae authored 作者: Melanie Ducoffe's avatar Melanie Ducoffe

adds dtype in props of AllocEmpty and remove debug print

上级 4ec148bd
...@@ -5473,7 +5473,7 @@ class Choose(Op): ...@@ -5473,7 +5473,7 @@ class Choose(Op):
class AllocEmpty(gof.Op): class AllocEmpty(gof.Op):
"""Implement Alloc on the gpu, but without initializing memory.""" """Implement Alloc on the gpu, but without initializing memory."""
__props__ = () __props__ = ("dtype",)
# specify the type of the data # specify the type of the data
def __init__(self, dtype): def __init__(self, dtype):
...@@ -5550,8 +5550,7 @@ class AllocEmpty(gof.Op): ...@@ -5550,8 +5550,7 @@ class AllocEmpty(gof.Op):
return [node.inputs] return [node.inputs]
def c_code_cache_version(self): def c_code_cache_version(self):
return None return (2,)
#return (1,)
def do_constant_folding(self, node): def do_constant_folding(self, node):
return False return False
...@@ -1830,21 +1830,18 @@ def local_dot22_to_ger_or_gemv(node): ...@@ -1830,21 +1830,18 @@ def local_dot22_to_ger_or_gemv(node):
# x and y are both vectors so this qualifies for a sdot / ddot # x and y are both vectors so this qualifies for a sdot / ddot
# TODO: Theano doesn't have a sdot, but gemv is better than _dot22 # TODO: Theano doesn't have a sdot, but gemv is better than _dot22
xv = x.dimshuffle(1) xv = x.dimshuffle(1)
#zeros = T.zeros([1], x.dtype)
zeros = T.AllocEmpty(x.dtype)(1) zeros = T.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle('x', 0)] return [rval.dimshuffle('x', 0)]
if xb[0] and not yb[0] and not yb[1]: if xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv # x is vector, y is matrix so try gemv
xv = x.dimshuffle(1) xv = x.dimshuffle(1)
#zeros = T.zeros([y.shape[1]], x.dtype)
zeros = T.AllocEmpty(x.dtype)(y.shape[1]) zeros = T.AllocEmpty(x.dtype)(y.shape[1])
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle('x', 0)] return [rval.dimshuffle('x', 0)]
if not xb[0] and not xb[1] and yb[1]: if not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv # x is matrix, y is vector, try gemv
yv = y.dimshuffle(0) yv = y.dimshuffle(0)
#zeros = T.zeros([x.shape[0]], dtype=x.dtype)
zeros = T.AllocEmpty(x.dtype)(x.shape[0]) zeros = T.AllocEmpty(x.dtype)(x.shape[0])
rval = gemv_no_inplace(zeros, one, x, yv, zero) rval = gemv_no_inplace(zeros, one, x, yv, zero)
return [rval.dimshuffle(0, 'x')] return [rval.dimshuffle(0, 'x')]
......
...@@ -930,6 +930,7 @@ def test_dot22scalar(): ...@@ -930,6 +930,7 @@ def test_dot22scalar():
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 5) check_dot22scalar(f, 5)
#print (av.dtype, bv.dtype, cv.dtype)
f(av, bv, cv) f(av, bv, cv)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论