提交 d8ea8892 authored 作者: Frederic Bastien's avatar Frederic Bastien

In GpuElemwise, don't call Elemwise.make_node, small speed up and help…

In GpuElemwise, don't call Elemwise.make_node, small speed up and help investigate. Also build directly a GpuDimShuffle instead of a DimShuffle that will be lifted later.
上级 02bde7ea
...@@ -51,13 +51,15 @@ class GpuElemwise(HideC, Elemwise): ...@@ -51,13 +51,15 @@ class GpuElemwise(HideC, Elemwise):
def make_node(self, *inputs): def make_node(self, *inputs):
ctx_name = infer_context_name(*inputs) ctx_name = infer_context_name(*inputs)
res = Elemwise.make_node(self, *inputs) inputs = [as_gpuarray_variable(i, ctx_name) for i in inputs]
outputs = [GpuArrayType(broadcastable=o.type.broadcastable, out_info = Elemwise.get_output_info(self, GpuDimShuffle, *inputs)
inputs = out_info[2]
outputs = [GpuArrayType(broadcastable=br,
context_name=ctx_name, context_name=ctx_name,
dtype=o.type.dtype)() for o in res.outputs] dtype=dtype)() for dtype, br in
zip(out_info[0], out_info[1])]
if len(outputs) > 1: if len(outputs) > 1:
raise NotImplementedError() raise NotImplementedError()
inputs = [as_gpuarray_variable(i, ctx_name) for i in inputs]
node = Apply(self, inputs, outputs) node = Apply(self, inputs, outputs)
# Try to generate the kernel to catch SupportCodeErrors # Try to generate the kernel to catch SupportCodeErrors
......
...@@ -544,13 +544,11 @@ second dimension ...@@ -544,13 +544,11 @@ second dimension
self.scalar_op.nout) self.scalar_op.nout)
self._rehash() self._rehash()
def make_node(self, *inputs): def get_output_info(self, dim_shuffle, *inputs):
""" """Return the outputs dtype and broadcastable pattern and the
If the inputs have different number of dimensions, their shape dimshuffled inputs.
is left-completed to the greatest number of dimensions with 1s
using DimShuffle.
""" """
inputs = list(map(as_tensor_variable, inputs))
shadow = self.scalar_op.make_node( shadow = self.scalar_op.make_node(
*[get_scalar_type(dtype=i.type.dtype).make_variable() *[get_scalar_type(dtype=i.type.dtype).make_variable()
for i in inputs]) for i in inputs])
...@@ -565,7 +563,7 @@ second dimension ...@@ -565,7 +563,7 @@ second dimension
args.append(input) args.append(input)
else: else:
# TODO: use LComplete instead # TODO: use LComplete instead
args.append(DimShuffle( args.append(dim_shuffle(
input.type.broadcastable, input.type.broadcastable,
['x'] * difference + list(range(length)), ['x'] * difference + list(range(length)),
inplace=False)(input)) inplace=False)(input))
...@@ -601,7 +599,18 @@ second dimension ...@@ -601,7 +599,18 @@ second dimension
raise TypeError(( raise TypeError((
"Cannot do an inplace operation on incompatible data types.", "Cannot do an inplace operation on incompatible data types.",
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern))) ([i.type.dtype for i in inputs], out_dtypes, inplace_pattern)))
assert len(out_dtypes) == len(out_broadcastables)
return out_dtypes, out_broadcastables, inputs
def make_node(self, *inputs):
"""
If the inputs have different number of dimensions, their shape
is left-completed to the greatest number of dimensions with 1s
using DimShuffle.
"""
inputs = list(map(as_tensor_variable, inputs))
out_dtypes, out_broadcastables, inputs = self.get_output_info(
DimShuffle, *inputs)
outputs = [TensorType(dtype=dtype, broadcastable=broadcastable)() outputs = [TensorType(dtype=dtype, broadcastable=broadcastable)()
for dtype, broadcastable in izip(out_dtypes, for dtype, broadcastable in izip(out_dtypes,
out_broadcastables)] out_broadcastables)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论