Zulip Chat Archive

Stream: lean4

Topic: Confusion about Array map linearization


Number Eighteen (Jul 31 2024 at 15:45):

I am trying to understand a remark by @Henrik Böving regarding an implementation for FloatArray.map. I have the following implementation:

@[specialize, inline]
def FloatArray.mapAux1 (A C : FloatArray) (f : Float → Float)
                      (k : Nat) (k_le_s : k ≤ A.size) :
    FloatArray :=
  if sz_eq : k = A.size then C else
    FloatArray.mapAux1 A (C.push (f A[k])) f k.succ
      (Nat.lt_of_le_of_ne k_le_s sz_eq)

@[inline]
def FloatArray.map1 (A : FloatArray) (f : Float → Float) :
    FloatArray :=
  A.mapAux1 (FloatArray.mkEmpty A.size) f 0 (Nat.zero_le _)

which is suboptimal because it allocates a block of memory even when the input array has a single reference. Henrik suggested copying the implementation of unsafe map from Array, using uget and uset instead. So I did something similar:

@[inline, specialize]
partial def FloatArray.mapAux2 (A : FloatArray) (f : Float → Float)
                      (k : Nat) (k_le_s : k ≤ A.size) :
    FloatArray :=
  if sz_eq : k = A.size then A else
    let g := A.uget k.toUSize sorry
    let z := f g
    let A := A.uset k.toUSize 0 sorry
    FloatArray.mapAux2 (A.uset k.toUSize z sorry) f k.succ sorry

@[inline]
def FloatArray.map2 (A : FloatArray) (f : Float → Float) :
    FloatArray :=
  A.mapAux2 f 0 (Nat.zero_le _)

Let me also give a simple helper function to initialize a big array:

@[inline]
def FloatArray.init (val : Float) (k : Nat)
                    (curr := FloatArray.mkEmpty k) :
    FloatArray :=
  match k with
  | 0 => curr
  | l + 1 => FloatArray.init val l (curr.push val)

Now I tested both map1 and map2 on the following code:

def main : IO Unit := do
  let s := 500 * 1000 * 1000
  let g := FloatArray.init 17 s
  let h := g.map2 (. + 1719191.0) |>.map2 (. + 328921.1) |>.map2 (. - 339838.1)
  IO.println <| h[0]!

Above I am using map2. With either map1 or map2, the code above took about the same time (approx. 45 sec, with map2 consistently slower) and saturated the memory; in particular, the consecutive map calls allocated new memory in both cases, so my machine started swapping.

So am I doing something wrong?
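
For a sense of scale, here is the payload size of one array in the test above. This is only a back-of-the-envelope figure: it assumes each Float element is stored as an 8-byte double and ignores the array header and any spare capacity.

```lean
-- 500 * 1000 * 1000 elements, 8 bytes per element (assumed element size)
#eval 500 * 1000 * 1000 * 8  -- 4000000000 bytes, i.e. about 4 GB
```

So every map call that does not reuse its input in place would need roughly another 4 GB.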

Henrik Böving (Jul 31 2024 at 15:58):

If you use proper specialization annotations you can already get the two versions to perform very close to each other:

@[specialize]
def FloatArray.mapAux1 (A C : FloatArray) (f : Float → Float)
                      (k : Nat) (k_le_s : k ≤ A.size) :
    FloatArray :=
  if sz_eq : k = A.size then C else
    FloatArray.mapAux1 A (C.push (f A[k])) f k.succ
      (Nat.lt_of_le_of_ne k_le_s sz_eq)

@[specialize]
def FloatArray.map1 (A : FloatArray) (f : Float → Float) :
    FloatArray :=
  A.mapAux1 (FloatArray.mkEmpty A.size) f 0 (Nat.zero_le _)

@[specialize]
partial def FloatArray.mapAux2 (A : FloatArray) (f : Float → Float)
                      (k : Nat) (k_le_s : k ≤ A.size) :
    FloatArray :=
  if sz_eq : k = A.size then A else
    let g := A.uget k.toUSize sorry
    let z := f g
    let A := A.uset k.toUSize z sorry
    FloatArray.mapAux2 A f (k + 1) sorry

@[specialize]
def FloatArray.map2 (A : FloatArray) (f : Float → Float) :
    FloatArray :=
  A.mapAux2 f 0 (Nat.zero_le _)

Looking at the IR will probably reveal what exactly is going wrong with the allocations.
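
One way to look at the IR is to turn on the compiler trace for a single definition. The snippet below is a minimal sketch on a throwaway function (`pushOne` is a made-up name, not part of the code above); the same `set_option ... in` prefix could be put in front of any of the map definitions.

```lean
-- Print the final IR the compiler generates for the next declaration only.
set_option trace.compiler.ir.result true in
def pushOne (a : FloatArray) : FloatArray :=
  a.push 1.0
```

In the printed IR, the places to look are the reference-count instructions (`inc`, `dec`) around the array and whether the array argument is passed owned or borrowed.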

François G. Dorais (Jul 31 2024 at 17:38):

You're missing the point of using uget/uset by keeping k : Nat instead of k : USize.

@[specialize]
unsafe def mapUnsafe (a : FloatArray) (f : Float → Float) : FloatArray :=
  loop a 0 a.size.toUSize -- a.usize was just added lean4#4801
where
  @[specialize] loop (a : FloatArray) (k s : USize) :=
    if k < s then
      let x := a.uget k lcProof
      let y := f x
      let a := a.uset k y lcProof
      loop a (k+1) s
    else a

Number Eighteen (Aug 01 2024 at 00:14):

François G. Dorais said:

@[specialize]
unsafe def mapUnsafe (a : FloatArray) (f : Float → Float) : FloatArray :=
  loop a 0 a.size.toUSize -- a.usize was just added lean4#4801
where
  @[specialize] loop (a : FloatArray) (k s : USize) :=
    if k < s then
      let x := a.uget k lcProof
      let y := f x
      let a := a.uset k y lcProof
      loop a (k+1) s
    else a

Is there a reason to declare the function unsafe rather than use sorry on the uset/uget?

François G. Dorais (Aug 01 2024 at 03:45):

Just to use lcProof and avoid termination issues.

François G. Dorais (Aug 01 2024 at 03:51):

Once FloatArray.usize makes it to release, you could get by with partial. Maybe not even that with some USize arithmetic. But it's still inherently unsafe since a.size = a.usize.toNat is not provable, it's just an implementation limit.
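
For comparison, here is a sketch of the partial-plus-sorry variant asked about above: the same loop with a USize counter, but with `sorry` in place of `lcProof` and `partial` in place of `unsafe`. The name `mapPartial` is made up for illustration and was not posted in the thread. It still relies on `a.size.toUSize` not truncating, which is the implementation limit mentioned above, and it will produce "declaration uses 'sorry'" warnings.

```lean
@[specialize]
partial def FloatArray.mapPartial (a : FloatArray) (f : Float → Float) :
    FloatArray :=
  loop a 0 a.size.toUSize
where
  -- `k` is the current index, `s` the (assumed non-truncated) size as USize.
  @[specialize] loop (a : FloatArray) (k s : USize) : FloatArray :=
    if k < s then
      let x := a.uget k sorry      -- bounds proof omitted
      let y := f x
      let a := a.uset k y sorry    -- bounds proof omitted
      loop a (k + 1) s
    else a
```

Because the bounds proofs are `sorry`, this is no safer than the unsafe version; `partial` only sidesteps the termination proof.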


Last updated: May 02 2025 at 03:31 UTC