提交 dc57295f authored 作者: James Bergstra's avatar James Bergstra

Made constant canonicalization [more?] correctly handle non-scalar constants.

上级 11247235
......@@ -597,13 +597,18 @@ def local_upcast_elemwise_constant_inputs(node):
Rationale: it helps merge things like (1-x) and (1.0 - x).
"""
if len(node.outputs)>1:
return
try:
shape_i = node.env.shape_feature.shape_i
except AttributeError:
shape_i = None
if isinstance(node.op, T.Elemwise):
scalar_op = node.op.scalar_op
#print "aa", scalar_op.output_types_preference
if getattr(scalar_op,'output_types_preference',None) in (T.scal.upgrade_to_float, T.scal.upcast_out):
# this is the kind of op that we can screw with the input dtypes by upcasting
# explicitly
#print "HELLO??"
output_dtype = node.outputs[0].type.dtype
new_inputs = []
for i in node.inputs:
......@@ -615,8 +620,11 @@ def local_upcast_elemwise_constant_inputs(node):
if 0==sum((not b for b in i.broadcastable)): # I mean all() but this might work in python2.4
new_inputs.append(T.cast(cval_i, output_dtype))
else:
if shape_i is None:
return
new_inputs.append(T.alloc(T.cast(cval_i, output_dtype),
*[Shape_i(d)(i) for d in xrange(i.ndim)]))
*[shape_i(d)(i) for d in xrange(i.ndim)]))
#print >> sys.stderr, "AAA", *[Shape_i(d)(i) for d in xrange(i.ndim)]
except TypeError:
if isinstance(i, T.TensorConstant): #for the case of a non-scalar
new_inputs.append(T.cast(i, output_dtype))
......@@ -624,8 +632,15 @@ def local_upcast_elemwise_constant_inputs(node):
new_inputs.append(i)
if new_inputs != node.inputs:
return [node.op(*new_inputs)]
rval = [node.op(*new_inputs)]
if rval[0].type != node.outputs[0].type:
print >> sys.stderr, "NODE:", node
print >> sys.stderr, "NODE INPUT TYPES:", [i.type for i in node.inputs]
print >> sys.stderr, "RVAL:", rval
print >> sys.stderr, "NEW INPUT TYPES:", [i.type for i in new_inputs]
print >> sys.stderr, "RVAL INPUT TYPES:", [i.type for i in rval[0].owner.inputs]
assert rval[0].type == node.outputs[0].type, (node, rval[0])
return rval
##################
# Subtensor opts #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论