提交 fbb70f14 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fix things that were either screwed up by the rebase, or were new code

added in the 2 months since I originally submitted the pull request
上级 cbc51446
...@@ -262,15 +262,15 @@ class RepeatOp(theano.Op): ...@@ -262,15 +262,15 @@ class RepeatOp(theano.Op):
broadcastable=[False] broadcastable=[False]
else: else:
try: try:
const_reps = basic.get_constant_value(repeats) const_reps = basic.get_scalar_constant_value(repeats)
except basic.NotConstantError: except basic.NotScalarConstantError:
const_reps = None const_reps = None
if const_reps == 1: if const_reps == 1:
broadcastable = x.broadcastable broadcastable = x.broadcastable
else: else:
broadcastable = list(x.broadcastable) broadcastable = list(x.broadcastable)
broadcastable[self.axis] = False broadcastable[self.axis] = False
out_type = theano.tensor.TensorType(x.dtype, broadcastable) out_type = theano.tensor.TensorType(x.dtype, broadcastable)
return theano.Apply(self, [x, repeats], [out_type()]) return theano.Apply(self, [x, repeats], [out_type()])
......
...@@ -1410,7 +1410,7 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -1410,7 +1410,7 @@ def _check_rows_is_arange_len_labels(rows, labels):
def _is_const(z, val, approx=False): def _is_const(z, val, approx=False):
try: try:
maybe = opt.get_scalar_constant_value(z) maybe = opt.get_scalar_constant_value(z)
except TypeError: except tensor.NotScalarConstantError:
return False return False
if approx: if approx:
return numpy.allclose(maybe, val) return numpy.allclose(maybe, val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论