[InstCombine] Missed optimization: Fold `(sext(a) & sext(c1)) == c2` to `(a & c1) == c2` or `0`
Alive2 proof: https://alive2.llvm.org/ce/z/44YNAL (contains generalized pattern)
Motivating example
define i1 @src(i8 %a) {
entry:
  %conv = sext i8 %a to i32
  %0 = and i32 %conv, -2147483647
  %cmp = icmp eq i32 %0, 1
  ret i1 %cmp
}
can be folded to:
define i1 @tgt(i8 %a) {
entry:
  %and = and i8 %a, -127
  %cmp = icmp eq i8 %and, 1
  ret i1 %cmp
}
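For readers who want a quick sanity check without opening Alive2, the equivalence of @src and @tgt can be brute-forced over all 256 values of %a. This is only a minimal sketch in Python, not part of the original report; the helper name `sext` and the functions `src`/`tgt` are illustrative stand-ins for the IR above, with the constants written as unsigned bit patterns (-2147483647 is 0x80000001 as i32, -127 is 0x81 as i8).

```python
def sext(value: int, from_bits: int, to_bits: int) -> int:
    """Sign-extend a from_bits-wide bit pattern to to_bits (result is an unsigned pattern)."""
    sign = 1 << (from_bits - 1)
    signed = (value & (sign - 1)) - (value & sign)
    return signed & ((1 << to_bits) - 1)


def src(a: int) -> bool:
    # Mirrors: sext i8 %a to i32; and i32 %conv, -2147483647; icmp eq i32 %0, 1
    return (sext(a, 8, 32) & 0x80000001) == 1


def tgt(a: int) -> bool:
    # Mirrors: and i8 %a, -127; icmp eq i8 %and, 1
    return (a & 0x81) == 1


# Collect every i8 bit pattern on which the two functions disagree.
mismatches = [a for a in range(256) if src(a) != tgt(a)]
print(mismatches)  # prints []
```

Both functions are true exactly when %a is non-negative and odd, which is why the list of mismatches is empty.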
Real-world motivation
This IR snippet is derived from abseil-cpp/float_conversion.cc@FloatToBuffer (after the O3 pipeline). The example above is a reduced version. If you're interested in the original suboptimal and optimal IR, please email me.
Let me know if you can confirm that it's an optimization opportunity, thanks.
@XChy @EugeneZelenko Can I work on this ticket?
@Abhinkop Are you still working on this? If not, can I give it a try?
@leewei05, it seems that @Abhinkop is not working on this, so feel free to fix it. Please take a look at the guide before submitting your patch.
@XChy @dtcxzyw Hello! I have two questions regarding this issue.
First, in the Alive2 proof https://alive2.llvm.org/ce/z/3Wjdak , src2 is optimized to
define i1 @src2(i8 %a, i8 %b, i32 %c) local_unnamed_addr #1 {
  %as = icmp ult i32 %c, 128
  tail call void @llvm.assume(i1 %as)
  %0 = and i8 %b, %a
  %1 = sext i8 %0 to i32
  %cmp = icmp eq i32 %c, %1
  ret i1 %cmp
}
Does this mean that I should match this pattern instead?
Second, I was wondering how the first example can be matched with the general pattern ((sext(a) & sext(c1)) == c2).
Isn't -2147483647 represented as 0x80000001 in 32 bits? Is there a c1 that can be sign-extended to 32 bits as -2147483647?
define i1 @src(i8 %a) {
entry:
  %conv = sext i8 %a to i32
  %0 = and i32 %conv, -2147483647
  %cmp = icmp eq i32 %0, 1
  ret i1 %cmp
}
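The second question can also be checked by brute force: enumerate every i8 value, sign-extend it to 32 bits, and see whether any result equals the bit pattern of -2147483647. The sketch below is illustrative only; `sext` and `TARGET` are hypothetical names, not anything from LLVM.

```python
def sext(value: int, from_bits: int, to_bits: int) -> int:
    """Sign-extend a from_bits-wide bit pattern to to_bits (result is an unsigned pattern)."""
    sign = 1 << (from_bits - 1)
    signed = (value & (sign - 1)) - (value & sign)
    return signed & ((1 << to_bits) - 1)


# -2147483647 viewed as a 32-bit pattern.
TARGET = -2147483647 & 0xFFFFFFFF

# Every i8 constant c1 whose sign extension to i32 equals TARGET.
candidates = [c1 for c1 in range(256) if sext(c1, 8, 32) == TARGET]
print(hex(TARGET), candidates)  # prints 0x80000001 []
```

No i8 constant qualifies, because a sign-extended i8 has bits 7 through 31 all equal, while 0x80000001 has bit 31 set and bits 7 through 30 clear.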
I'm trying to match a more generic pattern since my initial PR is too specific. Thank you in advance!
@leewei05 I'm sorry for my mistake in this old issue. You're right, (sext(a) & sext(c1)) == c2 doesn't match the motivating example.
A correct and more general pattern should be (sext(a) & c1) == c2 --> (a & c3) == trunc(c2), where c3 is another constant derived from c1, and c2 must be positive.
The method for constructing a proper c3 is as follows:
Assuming the bitwidth of a is 3 and the bitwidth of c1 is 5, we look at their bits:
a = S A1 A2
sext(a) = S S S A1 A2
c1 = X1 X2 X3 B1 B2
c2 = 0 0 0 C1 C2
We truncate the 5-bit c1 into a 3-bit c3 = (X1 | X2 | X3) B1 B2, so the highest bit of a & c3 is S & (X1 | X2 | X3), which is 0 exactly when (S & X1) == 0, (S & X2) == 0, and (S & X3) == 0.
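The construction above can be sketched and exhaustively checked for the 3-bit/5-bit case. This is a rough Python model, not the InstCombine patch and not a replacement for the Alive2 proof; `sext`, `make_c3`, and `fold_holds` are hypothetical helper names. It assumes c2 has the shape drawn in the diagram (0 0 0 C1 C2, i.e. all bits at and above the sign-bit position of a are zero), so trunc(c2) is the same value as c2.

```python
def sext(value: int, from_bits: int, to_bits: int) -> int:
    """Sign-extend a from_bits-wide bit pattern to to_bits (result is an unsigned pattern)."""
    sign = 1 << (from_bits - 1)
    signed = (value & (sign - 1)) - (value & sign)
    return signed & ((1 << to_bits) - 1)


def make_c3(c1: int, narrow_bits: int) -> int:
    """Hypothetical helper: OR the high bits of c1 into the narrow sign bit, keep the low bits."""
    low_mask = (1 << (narrow_bits - 1)) - 1   # positions of B1 B2
    low = c1 & low_mask
    high = c1 >> (narrow_bits - 1)            # X1 X2 X3 (c1 is an unsigned pattern)
    top = 1 if high != 0 else 0               # X1 | X2 | X3
    return (top << (narrow_bits - 1)) | low


def fold_holds(narrow_bits: int, wide_bits: int) -> bool:
    """Check (sext(a) & c1) == c2 against (a & c3) == trunc(c2) for all a, c1, and allowed c2."""
    for c1 in range(1 << wide_bits):
        c3 = make_c3(c1, narrow_bits)
        # Assumption: c2 only uses the bits below a's sign bit, as in the diagram.
        for c2 in range(1 << (narrow_bits - 1)):
            for a in range(1 << narrow_bits):
                lhs = (sext(a, narrow_bits, wide_bits) & c1) == c2
                rhs = (a & c3) == c2          # trunc(c2) == c2 for these c2
                if lhs != rhs:
                    return False
    return True


print(fold_holds(3, 5))              # prints True
print(hex(make_c3(0x80000001, 8)))   # prints 0x81
```

Applied to the motivating example, c1 = 0x80000001 with an i8 operand yields c3 = 0x81, which is -127 as an i8 and matches the constant in @tgt.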
It's a bit hard to formalize it in Alive2. I'll give it a try later.
Formal alive2 proof: https://alive2.llvm.org/ce/z/iQUQfE
@XChy Thanks for updating the proof! I've updated my PR based on your proof!