提交 dc57295f authored 作者: James Bergstra's avatar James Bergstra

Made constant canonicalization [more?] correctly handle non-scalar constants.

上级 11247235
...@@ -597,13 +597,18 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -597,13 +597,18 @@ def local_upcast_elemwise_constant_inputs(node):
Rationale: it helps merge things like (1-x) and (1.0 - x). Rationale: it helps merge things like (1-x) and (1.0 - x).
""" """
if len(node.outputs)>1:
return
try:
shape_i = node.env.shape_feature.shape_i
except AttributeError:
shape_i = None
if isinstance(node.op, T.Elemwise): if isinstance(node.op, T.Elemwise):
scalar_op = node.op.scalar_op scalar_op = node.op.scalar_op
#print "aa", scalar_op.output_types_preference #print "aa", scalar_op.output_types_preference
if getattr(scalar_op,'output_types_preference',None) in (T.scal.upgrade_to_float, T.scal.upcast_out): if getattr(scalar_op,'output_types_preference',None) in (T.scal.upgrade_to_float, T.scal.upcast_out):
# this is the kind of op that we can screw with the input dtypes by upcasting # this is the kind of op that we can screw with the input dtypes by upcasting
# explicitly # explicitly
#print "HELLO??"
output_dtype = node.outputs[0].type.dtype output_dtype = node.outputs[0].type.dtype
new_inputs = [] new_inputs = []
for i in node.inputs: for i in node.inputs:
...@@ -615,8 +620,11 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -615,8 +620,11 @@ def local_upcast_elemwise_constant_inputs(node):
if 0==sum((not b for b in i.broadcastable)): # I mean all() but this might work in python2.4 if 0==sum((not b for b in i.broadcastable)): # I mean all() but this might work in python2.4
new_inputs.append(T.cast(cval_i, output_dtype)) new_inputs.append(T.cast(cval_i, output_dtype))
else: else:
if shape_i is None:
return
new_inputs.append(T.alloc(T.cast(cval_i, output_dtype), new_inputs.append(T.alloc(T.cast(cval_i, output_dtype),
*[Shape_i(d)(i) for d in xrange(i.ndim)])) *[shape_i(d)(i) for d in xrange(i.ndim)]))
#print >> sys.stderr, "AAA", *[Shape_i(d)(i) for d in xrange(i.ndim)]
except TypeError: except TypeError:
if isinstance(i, T.TensorConstant): #for the case of a non-scalar if isinstance(i, T.TensorConstant): #for the case of a non-scalar
new_inputs.append(T.cast(i, output_dtype)) new_inputs.append(T.cast(i, output_dtype))
...@@ -624,8 +632,15 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -624,8 +632,15 @@ def local_upcast_elemwise_constant_inputs(node):
new_inputs.append(i) new_inputs.append(i)
if new_inputs != node.inputs: if new_inputs != node.inputs:
return [node.op(*new_inputs)] rval = [node.op(*new_inputs)]
if rval[0].type != node.outputs[0].type:
print >> sys.stderr, "NODE:", node
print >> sys.stderr, "NODE INPUT TYPES:", [i.type for i in node.inputs]
print >> sys.stderr, "RVAL:", rval
print >> sys.stderr, "NEW INPUT TYPES:", [i.type for i in new_inputs]
print >> sys.stderr, "RVAL INPUT TYPES:", [i.type for i in rval[0].owner.inputs]
assert rval[0].type == node.outputs[0].type, (node, rval[0])
return rval
################## ##################
# Subtensor opts # # Subtensor opts #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论