提交 d1cc7c4b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

tested all elemwise ops and fixed bugs in scalar implementations

上级 bf3df169
差异被折叠。
...@@ -149,13 +149,13 @@ class Broadcast(Op, Destroyer): ...@@ -149,13 +149,13 @@ class Broadcast(Op, Destroyer):
if ib and not ob: if ib and not ob:
raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.") raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.")
upcasted = upcast(*[input.dtype for input in inputs]) out_dtypes = [t.dtype for t in self.shadow.outputs]
def get_dtype(i): def get_dtype(i):
input_idx = inplace_pattern.get(i, None) input_idx = inplace_pattern.get(i, None)
if input_idx is not None: if input_idx is not None:
return inputs[input_idx].dtype return inputs[input_idx].dtype
else: else:
return upcasted return out_dtypes[i]
out_dtypes = map(get_dtype, xrange(self.nout)) out_dtypes = map(get_dtype, xrange(self.nout))
self.inputs = inputs self.inputs = inputs
self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)] self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
...@@ -201,6 +201,9 @@ class Broadcast(Op, Destroyer): ...@@ -201,6 +201,9 @@ class Broadcast(Op, Destroyer):
return bcasted return bcasted
ret = [] ret = []
for scalar_igrad, input in zip(scalar_igrads, inputs): for scalar_igrad, input in zip(scalar_igrads, inputs):
if scalar_igrad is None:
ret.append(None)
continue
r = transform(scalar_igrad) r = transform(scalar_igrad)
to_sum = [i for i, bcast in enumerate(input.broadcastable) if bcast] to_sum = [i for i, bcast in enumerate(input.broadcastable) if bcast]
if to_sum: if to_sum:
......
...@@ -282,6 +282,8 @@ class Div(BinaryScalarOp): ...@@ -282,6 +282,8 @@ class Div(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x / y return x / y
def c_code(self, (x, y), (z, ), sub): def c_code(self, (x, y), (z, ), sub):
if 'int' in self.inputs[0].dtype and 'int' in self.inputs[1].dtype:
raise NotImplementedError("For integer arguments the behavior of division in C and in Python differ when the quotient is negative (to implement).")
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz / y, -(gz * x) / (y * y) return gz / y, -(gz * x) / (y * y)
...@@ -346,13 +348,13 @@ class Abs(UnaryScalarOp): ...@@ -346,13 +348,13 @@ class Abs(UnaryScalarOp):
class Sgn(UnaryScalarOp): class Sgn(UnaryScalarOp):
def impl(self, x): def impl(self, x):
#casting to output type is handled by filter #casting to output type is handled by filter
return 1.0 if x >= 0 else -1.0 return numpy.sign(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return None, return None,
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
#casting is done by compiler #casting is done by compiler
#TODO: use copysign #TODO: use copysign
return "%(z)s = (%(x)s >= 0) ? 1.0 : -1.0;" % locals() return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals()
class Inv(FloatUnaryScalarOp): class Inv(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
...@@ -406,7 +408,7 @@ class Cos(FloatUnaryScalarOp): ...@@ -406,7 +408,7 @@ class Cos(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cos(x) return math.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * sin(x), return -gz * sin(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = cos(%(x)s);" % locals() return "%(z)s = cos(%(x)s);" % locals()
...@@ -414,7 +416,7 @@ class Sin(FloatUnaryScalarOp): ...@@ -414,7 +416,7 @@ class Sin(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sin(x) return math.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz * cos(x), return gz * cos(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals() return "%(z)s = sin(%(x)s);" % locals()
...@@ -440,13 +442,13 @@ class Sinh(FloatUnaryScalarOp): ...@@ -440,13 +442,13 @@ class Sinh(FloatUnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
raise NotImplementedError() raise NotImplementedError()
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals() return "%(z)s = sinh(%(x)s);" % locals()
class Tanh(FloatUnaryScalarOp): class Tanh(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tanh(x) return math.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * (1 - tanh(x))**2 return gz * (1 - tanh(x)**2),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = tanh(%(x)s);" % locals() return "%(z)s = tanh(%(x)s);" % locals()
......
...@@ -84,7 +84,7 @@ def astensor(data, broadcastable=None, name=None): ...@@ -84,7 +84,7 @@ def astensor(data, broadcastable=None, name=None):
raise ValueError("Cannot rename an existing Tensor.") raise ValueError("Cannot rename an existing Tensor.")
return data return data
elif isinstance(data, Result): elif isinstance(data, Result):
raise TypeError("Cannot make a Tensor out of a non-Tensor result.") raise TypeError("Cannot make a Tensor out of a non-Tensor result.", data)
if data is None and broadcastable is None: if data is None and broadcastable is None:
raise TypeError("Cannot make a Tensor out of None.") raise TypeError("Cannot make a Tensor out of None.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论