提交 67a685e0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix type problem when in dimension 0.

上级 dce8a601
...@@ -191,7 +191,7 @@ class Tensor(Type): ...@@ -191,7 +191,7 @@ class Tensor(Type):
Py_XDECREF(%(name)s); Py_XDECREF(%(name)s);
} }
""" % locals() """ % locals()
def c_sync(self, name, sub): def c_sync(self, name, sub):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
...@@ -1026,7 +1026,7 @@ class Dot(Op): ...@@ -1026,7 +1026,7 @@ class Dot(Op):
if nx not in (1,2): raise TypeError('not matrix or vector', x) if nx not in (1,2): raise TypeError('not matrix or vector', x)
if ny not in (1,2): raise TypeError('not matrix or vector', y) if ny not in (1,2): raise TypeError('not matrix or vector', y)
if nx == 2 and ny == 2: if nx == 2 and ny == 2:
bz = [x.type.broadcastable[0], y.type.broadcastable[1]] bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
elif nx == 1 and ny == 2: elif nx == 1 and ny == 2:
...@@ -1041,7 +1041,7 @@ class Dot(Op): ...@@ -1041,7 +1041,7 @@ class Dot(Op):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, (x, y), (z, )): def perform(self, node, (x, y), (z, )):
z[0] = numpy.dot(x, y) z[0] = numpy.asarray(numpy.dot(x, y))
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0: if gz.type.ndim == 0:
return gz * y, gz * x return gz * y, gz * x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论