added scalar upcasting to ScalarMixedOp

上级 f5c82cdc
...@@ -176,6 +176,7 @@ class Scalar(ResultBase): ...@@ -176,6 +176,7 @@ class Scalar(ResultBase):
class ScalarMixedOp(GuardedOp): class ScalarMixedOp(GuardedOp):
"""Olivier: document this stuff! -JB"""
nin = -1 nin = -1
nout = 1 nout = 1
...@@ -185,7 +186,8 @@ class ScalarMixedOp(GuardedOp): ...@@ -185,7 +186,8 @@ class ScalarMixedOp(GuardedOp):
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \ raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \
% (self.__class__.__name__, len(inputs), self.nin)) % (self.__class__.__name__, len(inputs), self.nin))
inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes)) o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论