Zulip Chat Archive

Stream: lean4

Topic: Efficient lookup functions


Andrej Bauer (May 04 2023 at 15:42):

Is the following going to be compiled efficiently (i.e., better than linear in N)? If not, how can I compile it efficiently? What if I write the same construct using match ... with ...?

def cow : Nat → Nat
  | 0 => ...
  | 1 => ...
  | 2 => ...
  ...
  | N => ...
  | _ => ... -- fallback case

For instance, would it be faster to use red-black trees (or some other kind of search tree) and look up values in the trees?
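For comparison, a sorted array searched by bisection gives the same O(log N) lookups as a balanced search tree, with less pointer chasing. This is a swapped-in technique (sorted array plus binary search, not a red-black tree), sketched in C under the assumption that all keys are known up front and kept sorted; cow_keys, cow_values and cow_lookup are hypothetical names, and the data is borrowed from the cow test later in this thread.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical lookup table; keys must be in strictly increasing order.
   The data mirrors the `cow` test later in this thread. */
static const uint64_t cow_keys[]   = {0, 1, 2, 3, 5, 6, 7};
static const uint64_t cow_values[] = {100, 111, 102, 114, 106, 117, 108};
static const uint64_t cow_default  = 37;  /* the `| _ => ...` fallback */

/* Binary search over [lo, hi): O(log N) comparisons instead of the
   O(N) chain of equality tests. */
static uint64_t cow_lookup(uint64_t n) {
    size_t lo = 0, hi = sizeof cow_keys / sizeof cow_keys[0];
    while (lo < hi) {
        size_t mid = lo + (hi - lo) / 2;
        if (cow_keys[mid] == n) return cow_values[mid];
        if (cow_keys[mid] < n) lo = mid + 1;
        else hi = mid;
    }
    return cow_default;
}
```

Each lookup probes at most floor(log2 N) + 1 keys; for the 7 keys above that is at most 3 probes.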

Eric Wieser (May 04 2023 at 15:43):

I would guess that a match ... with expression compiles identically to a direct pattern-matching definition.

Kyle Miller (May 04 2023 at 15:50):

I did a little test to see what kind of C code it produces: last I checked it was a linear search, but I wanted to be sure.

Here's my foo2.lean:

def cow : Nat → Nat
  | 0 => 100
  | 1 => 111
  | 2 => 102
  | 3 => 114
  | 5 => 106
  | 6 => 117
  | 7 => 108
  | _ => 37

Then, in the Scratch project, I ran lake build Scratch.foo2:c. The generated file build/ir/Scratch/foo2.c shows that it's a linear search:

LEAN_EXPORT lean_object* l_cow(lean_object* x_1) {
_start:
{
lean_object* x_2; uint8_t x_3;
x_2 = lean_unsigned_to_nat(0u);
x_3 = lean_nat_dec_eq(x_1, x_2);
if (x_3 == 0)
{
lean_object* x_4; uint8_t x_5;
x_4 = lean_unsigned_to_nat(1u);
x_5 = lean_nat_dec_eq(x_1, x_4);
if (x_5 == 0)
{
lean_object* x_6; uint8_t x_7;
x_6 = lean_unsigned_to_nat(2u);
x_7 = lean_nat_dec_eq(x_1, x_6);
if (x_7 == 0)
{
lean_object* x_8; uint8_t x_9;
x_8 = lean_unsigned_to_nat(3u);
x_9 = lean_nat_dec_eq(x_1, x_8);
if (x_9 == 0)
{
lean_object* x_10; uint8_t x_11;
x_10 = lean_unsigned_to_nat(5u);
x_11 = lean_nat_dec_eq(x_1, x_10);
if (x_11 == 0)
{
lean_object* x_12; uint8_t x_13;
x_12 = lean_unsigned_to_nat(6u);
x_13 = lean_nat_dec_eq(x_1, x_12);
if (x_13 == 0)
{
lean_object* x_14; uint8_t x_15;
x_14 = lean_unsigned_to_nat(7u);
x_15 = lean_nat_dec_eq(x_1, x_14);
if (x_15 == 0)
{
lean_object* x_16;
x_16 = lean_unsigned_to_nat(37u);
return x_16;
}
else
{
lean_object* x_17;
x_17 = lean_unsigned_to_nat(108u);
return x_17;
}
}
else
{
lean_object* x_18;
x_18 = lean_unsigned_to_nat(117u);
return x_18;
}
}
else
{
lean_object* x_19;
x_19 = lean_unsigned_to_nat(106u);
return x_19;
}
}
else
{
lean_object* x_20;
x_20 = lean_unsigned_to_nat(114u);
return x_20;
}
}
else
{
lean_object* x_21;
x_21 = lean_unsigned_to_nat(102u);
return x_21;
}
}
else
{
lean_object* x_22;
x_22 = lean_unsigned_to_nat(111u);
return x_22;
}
}
else
{
lean_object* x_23;
x_23 = lean_unsigned_to_nat(100u);
return x_23;
}
}
}

Kyle Miller (May 04 2023 at 15:56):

Maybe it's helpful to see mkEnumOfNat, which is used when deriving DecidableEq for an inductive type whose constructors are all 0-ary. It creates a definition that maps an index to the corresponding constructor in logarithmic time. Example:

inductive Animals
  | cow | horse | cat | dog | moose | shoop
  deriving DecidableEq

#print Animals.ofNat
/-
def Animals.ofNat : Nat → Animals :=
fun n =>
  bif Nat.ble 3 n then bif Nat.ble 4 n then bif Nat.beq n 4 then Animals.moose else Animals.shoop else Animals.dog
  else bif Nat.ble 1 n then bif Nat.beq n 1 then Animals.horse else Animals.cat else Animals.cow
-/
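The printed Animals.ofNat is a binary decision tree over the index: each Nat.ble k n test (k ≤ n) splits the remaining range, and a final Nat.beq picks between two neighbours. A hand transliteration into C, for illustration only (this is not code Lean emits, and the Animal enum and animal_of_nat are made-up names):

```c
/* The six constructors of the `Animals` example, in declaration order. */
typedef enum { COW, HORSE, CAT, DOG, MOOSE, SHOOP } Animal;

/* Mirrors the printed `Animals.ofNat`: nested `k <= n` tests halve the range,
   so the depth grows logarithmically with the number of constructors.
   Out-of-range indices land on the last constructor, as in the Lean term. */
static Animal animal_of_nat(unsigned long n) {
    if (3 <= n) {
        if (4 <= n) return n == 4 ? MOOSE : SHOOP;
        return DOG;
    }
    if (1 <= n) return n == 1 ? HORSE : CAT;
    return COW;
}
```

The same shape works for any sorted set of keys, which is one way to hand-write a logarithmic version of cow.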

Andrej Bauer (May 04 2023 at 16:05):

Thanks. I think we'll stick to search trees for now, those ought to work efficiently.

Henrik Böving (May 04 2023 at 16:05):

Note that even if we generate C code this way, the C compiler might still optimize it into the logarithmic variant, so I'd say the best way to find out is probably to benchmark.
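A minimal benchmarking sketch, assuming plain unboxed integers rather than Lean's lean_object* values: cow_chain has the shape of the emitted if-chain, cow_switch expresses the same mapping as a switch, and bench is a hypothetical helper that times repeated calls with clock() and returns a checksum so the work is not optimized away.

```c
#include <stddef.h>
#include <stdint.h>
#include <time.h>

/* Same mapping as the `cow` test in this thread, as a chain of
   equality tests (the shape of the C that Lean emits). */
static uint64_t cow_chain(uint64_t n) {
    if (n == 0) return 100;
    if (n == 1) return 111;
    if (n == 2) return 102;
    if (n == 3) return 114;
    if (n == 5) return 106;
    if (n == 6) return 117;
    if (n == 7) return 108;
    return 37;
}

/* The same mapping as a switch; optimizing compilers often lower
   dense switches to jump tables, but that is the compiler's choice. */
static uint64_t cow_switch(uint64_t n) {
    switch (n) {
    case 0: return 100;
    case 1: return 111;
    case 2: return 102;
    case 3: return 114;
    case 5: return 106;
    case 6: return 117;
    case 7: return 108;
    default: return 37;
    }
}

/* Calls f on 0..range-1, `reps` times. Writes elapsed CPU seconds to
   *seconds when it is non-NULL and returns the sum of all results. */
static uint64_t bench(uint64_t (*f)(uint64_t), uint64_t range, int reps,
                      double *seconds) {
    uint64_t sum = 0;
    clock_t start = clock();
    for (int r = 0; r < reps; r++)
        for (uint64_t n = 0; n < range; n++)
            sum += f(n);
    if (seconds)
        *seconds = (double)(clock() - start) / CLOCKS_PER_SEC;
    return sum;
}
```

Compile with optimizations (e.g. -O2) and compare the timings for a large reps. Calling through a function pointer keeps each lookup from being folded into the loop, at the cost of some call overhead in both variants.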

Kyle Miller (May 04 2023 at 16:46):

In this case the generated assembly isn't too bad to read. The C compiler turns this into the following algorithm:

  1. Check whether the argument is not a scalar (i.e., not an integer embedded in a pointer; Lean uses an encoding in which a pointer whose least-significant bit is 1 represents an integer). If so, fall back on a linear search.
  2. Otherwise, compute an index from the scalar, use it to index into a jump table, and jump to a piece of code that returns the correct value.

The jump table is somewhat wonky: for some reason, one block of code is shared between the fallback and one of the scalar cases (n = 7), so there is an extra test and a cmov in there.

Assuming all small integers are encoded as scalars (they should be), the algorithm is O(1), provided the compiler can still perform this jump-table transformation when there are many more cases.

l_cow:
    pushq   %rbx
    movq    %rdi, %rbx
    testb   $1, %bl
    je  .LBB0_1
    leaq    -1(%rbx), %rcx
    cmpq    $12, %rcx
    ja  .LBB0_10
    movl    $201, %eax
    leaq    .LJTI0_0(%rip), %rdx
    movslq  (%rdx,%rcx,4), %rcx
    addq    %rdx, %rcx
    jmpq    *%rcx
.LBB0_12:
    movl    $223, %eax
    popq    %rbx
    retq
.LBB0_10:
    cmpq    $15, %rbx
    sete    %al
.LBB0_11:
    testb   %al, %al
    movl    $217, %ecx
    movl    $75, %eax
    cmovneq %rcx, %rax
    popq    %rbx
    retq
.LBB0_13:
    movl    $205, %eax
    popq    %rbx
    retq
.LBB0_14:
    movl    $229, %eax
    popq    %rbx
    retq
.LBB0_15:
    movl    $213, %eax
    popq    %rbx
    retq
.LBB0_16:
    movl    $235, %eax
.LBB0_17:
    popq    %rbx
    retq
.LBB0_1:
    movl    $1, %esi
    movq    %rbx, %rdi
    callq   lean_nat_big_eq@PLT
    movl    %eax, %ecx
    movl    $201, %eax
    testb   %cl, %cl
    jne .LBB0_17
    movl    $3, %esi
    movq    %rbx, %rdi
    callq   lean_nat_big_eq@PLT
    movl    %eax, %ecx
    movl    $223, %eax
    testb   %cl, %cl
    jne .LBB0_17
    movl    $5, %esi
    movq    %rbx, %rdi
    callq   lean_nat_big_eq@PLT
    movl    %eax, %ecx
    movl    $205, %eax
    testb   %cl, %cl
    jne .LBB0_17
    movl    $7, %esi
    movq    %rbx, %rdi
    callq   lean_nat_big_eq@PLT
    movl    %eax, %ecx
    movl    $229, %eax
    testb   %cl, %cl
    jne .LBB0_17
    movl    $11, %esi
    movq    %rbx, %rdi
    callq   lean_nat_big_eq@PLT
    movl    %eax, %ecx
    movl    $213, %eax
    testb   %cl, %cl
    jne .LBB0_17
    movl    $13, %esi
    movq    %rbx, %rdi
    callq   lean_nat_big_eq@PLT
    movl    %eax, %ecx
    movl    $235, %eax
    testb   %cl, %cl
    jne .LBB0_17
    movl    $15, %esi
    movq    %rbx, %rdi
    callq   lean_nat_big_eq@PLT
    jmp .LBB0_11
.LJTI0_0:
    .long   .LBB0_17-.LJTI0_0
    .long   .LBB0_10-.LJTI0_0
    .long   .LBB0_12-.LJTI0_0
    .long   .LBB0_10-.LJTI0_0
    .long   .LBB0_13-.LJTI0_0
    .long   .LBB0_10-.LJTI0_0
    .long   .LBB0_14-.LJTI0_0
    .long   .LBB0_10-.LJTI0_0
    .long   .LBB0_10-.LJTI0_0
    .long   .LBB0_10-.LJTI0_0
    .long   .LBB0_15-.LJTI0_0
    .long   .LBB0_10-.LJTI0_0
    .long   .LBB0_16-.LJTI0_0
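A simplified C model of the scalar path, assuming the encoding described above: a small n is stored as 2n + 1 (the rule used by lean_box and lean_is_scalar in lean.h). The names box, unbox, is_scalar, cow_table and cow_boxed are illustrative, not Lean's API. The constants in the listing are boxed values ($201 is box(100), $75 is box(37)), and the compiled code indexes its table by x - 1 = 2n without unboxing, which is why the odd-numbered slots of .LJTI0_0 all point at the shared fallback block .LBB0_10. The sketch unboxes first for readability.

```c
#include <stdbool.h>
#include <stdint.h>

/* Scalar encoding: n is stored as 2n + 1, so the low bit distinguishes
   scalars from real (aligned, even) pointers. */
static uintptr_t box(uintptr_t n)       { return (n << 1) | 1; }
static uintptr_t unbox(uintptr_t x)     { return x >> 1; }
static bool      is_scalar(uintptr_t x) { return (x & 1) == 1; }

/* Dense table indexed by n = 0..7; slot 4 holds the fallback value 37,
   playing the role of the jump table's default entries. */
static const uintptr_t cow_table[] = {100, 111, 102, 114, 37, 106, 117, 108};
#define COW_TABLE_LEN (sizeof cow_table / sizeof cow_table[0])

/* Scalar path: one bounds check plus one table load, i.e. O(1).
   For non-scalars (bignums) this model just returns the fallback: under the
   encoding invariant discussed in this thread, a bignum never equals a small key. */
static uintptr_t cow_boxed(uintptr_t x) {
    if (is_scalar(x)) {
        uintptr_t n = unbox(x);
        if (n < COW_TABLE_LEN) return box(cow_table[n]);
    }
    return box(37);
}
```

Both inputs and results are boxed, so cow_boxed(box(3)) yields box(114), i.e. 229, the same constant the listing loads for that case.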

Mario Carneiro (May 04 2023 at 17:11):

Quoting Kyle Miller: "Assuming all small integers are encoded as scalars (they should be),"

This is a requirement of the encoding. "Small big integers" (small values stored as bignums rather than as scalars) will cause arithmetic operations to produce incorrect results.
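The failure mode can be seen in a toy model (hypothetical, not Lean's runtime code): an equality routine with a fast path that treats "one scalar, one bignum" as unequal without inspecting the bignum is only sound if no small value is ever stored as a bignum. toy_nat, toy_big and the 2^62 bound TOY_MAX_SCALAR are made up for this illustration.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical bound: values up to this are meant to be scalars. */
#define TOY_MAX_SCALAR ((uint64_t)1 << 62)

typedef struct { uint64_t value; } toy_big;   /* stand-in for a heap bignum */

typedef struct {
    bool is_scalar;
    uint64_t scalar;      /* meaningful only when is_scalar */
    const toy_big *big;   /* meaningful only when !is_scalar */
} toy_nat;

static toy_nat toy_scalar(uint64_t v)         { toy_nat x = {true, v, NULL}; return x; }
static toy_nat toy_from_big(const toy_big *b) { toy_nat x = {false, 0, b}; return x; }

/* The invariant: small values are scalars, and only large values are bignums. */
static bool toy_is_canonical(toy_nat x) {
    return x.is_scalar ? x.scalar <= TOY_MAX_SCALAR
                       : x.big->value > TOY_MAX_SCALAR;
}

/* Equality with a fast path: a scalar and a bignum are declared unequal
   without looking at the bignum, which is only correct for canonical values. */
static bool toy_nat_eq(toy_nat a, toy_nat b) {
    if (a.is_scalar && b.is_scalar) return a.scalar == b.scalar;
    if (a.is_scalar != b.is_scalar) return false;
    return a.big->value == b.big->value;
}
```

With these definitions, toy_nat_eq(toy_scalar(5), toy_from_big(&(toy_big){5})) is false even though both values represent 5, which is the kind of wrong answer a "small big integer" produces.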


Last updated: Dec 20 2023 at 11:08 UTC