提交 105bcc8c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Change input validation asserts in TensorFromScalar/ScalarFromTensor to exceptions

上级 a91644f3
......@@ -556,7 +556,9 @@ class TensorFromScalar(Op):
__props__ = ()
def make_node(self, s):
assert isinstance(s.type, aes.Scalar)
if not isinstance(s.type, aes.Scalar):
raise TypeError("Input must be a `Scalar` `Type`")
return Apply(self, [s], [tensor(dtype=s.type.dtype, broadcastable=())])
def perform(self, node, inp, out_):
......@@ -592,8 +594,9 @@ class ScalarFromTensor(COp):
__props__ = ()
def make_node(self, t):
assert isinstance(t.type, TensorType)
assert t.type.broadcastable == ()
if not isinstance(t.type, TensorType) or t.type.ndim > 0:
raise TypeError("Input must be a scalar `TensorType`")
return Apply(
self, [t], [aes.get_scalar_type(dtype=t.type.dtype).make_variable()]
)
......
......@@ -1844,6 +1844,9 @@ def test_TensorFromScalar():
g = grad(t, s)
assert eval_outputs([g]) == 0.0
with pytest.raises(TypeError):
tensor_from_scalar(vector())
def test_ScalarFromTensor():
tc = constant(56) # aes.constant(56)
......@@ -1872,6 +1875,9 @@ def test_ScalarFromTensor():
assert isinstance(v, np.int64)
assert v.shape == ()
with pytest.raises(TypeError):
scalar_from_tensor(vector())
class TestOpCache:
def test_basic(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论