提交 99c0fbb4 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier 提交者: --global

Ensure that condition is a theano tensor before using it

上级 ec0881a6
...@@ -68,6 +68,10 @@ class PdbBreakpoint(Op): ...@@ -68,6 +68,10 @@ class PdbBreakpoint(Op):
def make_node(self, condition, *monitored_vars): def make_node(self, condition, *monitored_vars):
# Ensure that condition is a theano tensor
if not isinstance(condition, theano.Variable):
condition = theano.tensor.as_tensor_variable(condition)
# Validate that the condition is a scalar (else it is not obvious how # Validate that the condition is a scalar (else it is not obvious how
# is should be evaluated) # is should be evaluated)
assert (condition.ndim == 0) assert (condition.ndim == 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论