提交 5992e548 authored 作者: nouiz's avatar nouiz

Merge pull request #178 from nouiz/important_fix2

Finish the fix to the change to ShapeFeature.
...@@ -815,10 +815,13 @@ nd_collapse_[i]=0; ...@@ -815,10 +815,13 @@ nd_collapse_[i]=0;
""" """
kernels = "".join( kernels = "".join(
[self.c_src_kernel(node, nodename, x) for x in xrange(1, nd + 1)] [self.c_src_kernel(node, nodename, x) for x in xrange(1, nd + 1)]
+ [self.c_src_kernel_Ccontiguous(node, nodename)], + [self.c_src_kernel_Ccontiguous(node, nodename)]
+ [self.c_src_callkernel(node, nodename)]) + [self.c_src_callkernel(node, nodename)])
return defines + kernels return defines + kernels
def c_support_code(self):
raise gof.utils.MethodNotDefined()
def c_code(self, node, nodename, inputs, outputs, sub): def c_code(self, node, nodename, inputs, outputs, sub):
d = dict(sub) d = dict(sub)
nd = node.outputs[0].type.ndim nd = node.outputs[0].type.ndim
......
...@@ -77,7 +77,7 @@ class InputToGpuOptimizer(Optimizer): ...@@ -77,7 +77,7 @@ class InputToGpuOptimizer(Optimizer):
if new_input.type==input.type: if new_input.type==input.type:
env.replace_validate(input, new_input, "InputToGpuOptimizer") env.replace_validate(input, new_input, "InputToGpuOptimizer")
except Exception, e: except TypeError, e:
#as we currently only support float32, this can fail. #as we currently only support float32, this can fail.
#Using try except make that we won't need #Using try except make that we won't need
pass pass
...@@ -243,7 +243,12 @@ def local_gpu_dot_to_dot22(node): ...@@ -243,7 +243,12 @@ def local_gpu_dot_to_dot22(node):
A more suitable solution would be to use the right cublas call A more suitable solution would be to use the right cublas call
""" """
# In case the op did an input upcast, we must check that we can
# make it run on the gpu.
if node.op == gpu_from_host: if node.op == gpu_from_host:
if node.outputs[0].type.dtype != 'float32':
return False
host_input = node.inputs[0] host_input = node.inputs[0]
if host_input.owner and host_input.owner.op == tensor.basic.dot: if host_input.owner and host_input.owner.op == tensor.basic.dot:
x, y = host_input.owner.inputs x, y = host_input.owner.inputs
...@@ -264,6 +269,8 @@ def local_gpu_dot_to_dot22(node): ...@@ -264,6 +269,8 @@ def local_gpu_dot_to_dot22(node):
return [GpuReshape(1)(gpu_dot22(gpu_x, gpu_y), shape_out)] return [GpuReshape(1)(gpu_dot22(gpu_x, gpu_y), shape_out)]
if node.op == tensor.basic.dot: if node.op == tensor.basic.dot:
if node.outputs[0].type.dtype != 'float32':
return False
if numpy.any([(i.owner and i.owner.op == host_from_gpu) for i in node.inputs]): if numpy.any([(i.owner and i.owner.op == host_from_gpu) for i in node.inputs]):
x, y = node.inputs x, y = node.inputs
if _is_real_vector(x) and _is_real_matrix(y): if _is_real_vector(x) and _is_real_matrix(y):
...@@ -1267,7 +1274,7 @@ def gpu_scan_make_inplace(node): ...@@ -1267,7 +1274,7 @@ def gpu_scan_make_inplace(node):
n_outs = len(ls) n_outs = len(ls)
for idx in xrange(n_outs): for idx in xrange(n_outs):
if ls[idx] in ls[:idx]: if ls[idx] in ls[:idx]:
ls[idx] = deep_copy_op(ls[idx]) ls[idx] = compile.function_module.deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end inputs = ls_begin + ls + ls_end
......
...@@ -922,9 +922,17 @@ class ShapeFeature(object): ...@@ -922,9 +922,17 @@ class ShapeFeature(object):
self.set_shape(r, s) self.set_shape(r, s)
def on_change_input(self, env, node, i, r, new_r): def on_change_input(self, env, node, i, r, new_r):
if new_r not in self.shape_of:
# It can happen that the env didn't call on_import for some
# new_r. This happens when new_r doesn't have an
# owner (i.e. it is a constant or an input of the graph).
# update_shape assumes that r and new_r are in shape_of.
self.init_r(new_r)
# This tells us that r and new_r must have the same shape # This tells us that r and new_r must have the same shape
# if we didn't know that the shapes are related, now we do. # if we didn't know that the shapes are related, now we do.
self.update_shape(new_r, r) # We should give priority to new_r, so we put it last.
self.update_shape(r, new_r)
# change_input happens in two cases: # change_input happens in two cases:
# 1) we are trying to get rid of r, or # 1) we are trying to get rid of r, or
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论