This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
core_op=node.op.core_op
ifnot(isinstance(core_op,ALL_INVERSE_OPS)):
returnNone
inputs=node.inputs[0]
# Check for use of pt.diag first
if(
inputs.owner
andisinstance(inputs.owner.op,AllocDiag)
andAllocDiag.is_offset_zero(inputs.owner)
):
inv_input=inputs.owner.inputs[0]
inv_val=pt.diag(1/inv_input)
return[inv_val]
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none=_find_diag_from_eye_mul(inputs)
ifinputs_or_noneisNone:
returnNone
eye_input,non_eye_inputs=inputs_or_none
# Dealing with only one other input
iflen(non_eye_inputs)!=1:
returnNone
non_eye_input=non_eye_inputs[0]
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those