Zulip Chat Archive

Stream: mathlib4

Topic: termination_by in Gaussian elimination


Moritz Firsching (May 05 2023 at 12:55):

I implemented a naive Gaussian elimination and it works in the sense that I can now invert invertible matrices, but I have trouble showing termination of my recursion.

import Mathlib.Tactic
import Mathlib.Data.Matrix.Notation
import Mathlib.Algebra.Field.Basic

variable {α : Type _} [Field α] {m n : ℕ} [NeZero m] [NeZero n] [DecidableEq α]

def swap_rows (A : Matrix (Fin m) (Fin n) α) (i j : Fin m) :
    Matrix (Fin m) (Fin n) α :=
  fun r c ↦ if r = i then A j c else if r = j then A i c else A r c

def scale_row (A : Matrix (Fin m) (Fin n) α) (i : Fin m) (k : α) :
    Matrix (Fin m) (Fin n) α :=
  fun r c ↦ if r = i then k * A r c else A r c

def add_scaled_row (A : Matrix (Fin m) (Fin n) α) (i j : Fin m) (k : α) :
    Matrix (Fin m) (Fin n) α :=
  fun r c ↦ if r = i then A r c + k * A j c else A r c

def gaussian_elimination_step (A : Matrix (Fin m) (Fin n) α) (i : Fin m) (j : Fin n) :
    Matrix (Fin m) (Fin n) α :=
  if _ : A i j ≠ 0 then
    let k := (A i j)⁻¹
    let A' := scale_row A i k
    List.foldl (fun (A'' : Matrix (Fin m) (Fin n) α) (r : Fin m) ↦
                             if r ≠ i
                             then add_scaled_row A'' r i (-A'' r j)
                             else A'') A' (List.finRange m)
  else A

def step : Matrix (Fin m) (Fin n) α → ℕ → ℕ → Matrix (Fin m) (Fin n) α
  | A', i, j =>
    if h : i < m ∧ j < n then
      let rows_to_search := List.drop i (List.finRange m)
      let pivot := List.find? (fun r => A' r ⟨j, (h.right)⟩ != 0) rows_to_search
      match pivot with
      | none => step A' i (j + 1)
      | some r =>
        let A'' := swap_rows A' ⟨i, h.left⟩ r
        let A''' := gaussian_elimination_step A'' ⟨i, h.left⟩ ⟨j, h.right⟩
        step A''' (i + 1) (j + 1)
    else A'
termination_by _ i j => (m - i) * (n - j)
decreasing_by sorry

def gaussian_elimination (A : Matrix (Fin m) (Fin n) α) : Matrix (Fin m) (Fin n) α :=
  step A 0 0

def augment_with_identity (A : Matrix (Fin n) (Fin n) α) : Matrix (Fin n) (Fin (n + n)) α :=
  fun i j ↦ if j.val < n then A i j else (1 : Matrix (Fin n) (Fin n) α) i (j - n)

def extract_inverse (A : Matrix (Fin n) (Fin (n + n)) α) : Matrix (Fin n) (Fin n) α :=
  fun i j ↦ by
    have h : n + j < n + n := by simp only [add_lt_add_iff_left, Fin.is_lt]
    exact A i ⟨n + j, h⟩

def matrix_inverse (A : Matrix (Fin n) (Fin n) α) : Matrix (Fin n) (Fin n) α :=
  let augmented := augment_with_identity A
  extract_inverse <| gaussian_elimination augmented


-- invertible example
def M := !![(2 : ℚ), -1/2; -3, 1]
def N := matrix_inverse M
#eval M*N == 1
#eval N*M == 1

-- non-invertible example
def M' := !![(2 : ℚ), -1/2; 6, -3/2]
def N' := matrix_inverse M'
#eval N'
#eval M'*N' == 1

How do I show termination? I'm not even sure that termination_by _ i j => (m - i) * (n - j) is the right place to start. In words, I'd argue that the recursion ends because i or j increases in each step, and as soon as they are at least m or n respectively, we are finished.

Scott Morrison (May 05 2023 at 22:49):

A brief answer: if you want to do termination_by some complicated expression, you will need to bring the fact that the quantity has decreased into the context of each recursive call. That is, the fact that it is decreasing should be proved inline with the code. This is contrary to the usual practice of proving things after defining them, unfortunately, and so sufficiently complicated situations require more thought than just using termination_by.
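
To illustrate the pattern on something smaller than the matrix code, here is a minimal sketch. The function countUp is a hypothetical toy example, not something from this thread, and the termination_by line uses the same 2023-era syntax as the code in this topic (newer Lean releases write this clause differently). The point is only that the `have` sits right before the recursive call, so the default decreasing tactic can pick it up from the local context.

```lean
import Mathlib.Tactic

-- Hypothetical toy function (not from the thread): counts the steps from `i` up to `n`.
def countUp (n i : Nat) : Nat :=
  if h : i < n then
    -- State the decrease of the measure `n - i` inline, before the recursive call.
    -- `Nat.sub_lt_sub_left : k < m → k < n → m - n < m - k`
    have : n - (i + 1) < n - i := Nat.sub_lt_sub_left h (Nat.lt_succ_self i)
    1 + countUp n (i + 1)
  else
    0
-- Older syntax, as used elsewhere in this topic; newer Lean versions
-- write `termination_by n - i` instead.
termination_by countUp n i => n - i
```

For a function this simple the default tactic would probably succeed even without the `have`; the explicit hypothesis matters once the measure gets more complicated, as with two indices.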

Moritz Firsching (May 06 2023 at 11:48):

Thanks! That helped a lot, I'm able to show termination now (not at all golfed yet...):

def step : Matrix (Fin m) (Fin n) α → ℕ → ℕ → Matrix (Fin m) (Fin n) α
  | A', i, j =>
    if h : i < m ∧ j < n then
      have hn : n - (j + 1) < n - j := by
        refine' Nat.sub_lt_sub_left  _ _
        linarith
        simp only [lt_add_iff_pos_right]
      have hm : m - (i + 1) < m - i := by
        refine' Nat.sub_lt_sub_left  _ _
        linarith
        simp only [lt_add_iff_pos_right]
      let rows_to_search := List.drop i (List.finRange m)
      let pivot := List.find? (fun r => A' r ⟨j, (h.right)⟩ != 0) rows_to_search
      match pivot with
      | none =>
        step A' i (j + 1)
      | some r =>
        let A'' := swap_rows A' ⟨i, h.left⟩ r
        let A''' := gaussian_elimination_step A'' ⟨i, h.left⟩ ⟨j, h.right⟩
        have : m - (i + 1) + (n - (j + 1)) < m - i + (n - j) := by linarith
        step A''' (i + 1) (j + 1)
    else A'
termination_by step _ i j => (m - i) + (n - j)

Kyle Miller (May 06 2023 at 12:12):

If you're planning on running this algorithm in practice, watch out, since it will have worse asymptotic performance than you might expect. This is because the matrices are functions, so, for example, swap_rows creates a new closure that calculates entries from scratch every time. Using arrays of arrays (or writing a function that caches all the values of a matrix as an array of arrays and then creates a new Matrix that reads from these arrays) would be a way to fix this.
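
A rough sketch of the caching idea might look like the following. The name cacheMatrix is hypothetical and not part of mathlib or this thread, the code is untested, and whether the table is really shared between lookups depends on how the compiler treats the returned closure, so take it as an assumption rather than a guarantee.

```lean
import Mathlib.Data.Matrix.Notation

variable {α : Type _} [Zero α] {m n : ℕ}

/-- Hypothetical helper (not from the thread): evaluate every entry of `A` once,
store the results in an array of arrays, and return a matrix that only reads
from that table. -/
def cacheMatrix (A : Matrix (Fin m) (Fin n) α) : Matrix (Fin m) (Fin n) α :=
  -- Intended to be built once per call of `cacheMatrix A` and captured by the closure.
  let rows : Array (Array α) := Array.ofFn fun i => Array.ofFn fun j => A i j
  -- `getD` needs a default for out-of-range indices; these cannot occur here,
  -- because `i.val < m` and `j.val < n`.
  fun i j => (rows.getD i.val #[]).getD j.val 0
```

The intended use would be to wrap each intermediate result once, for example cacheMatrix (swap_rows A i j), and keep working with the cached matrix, so that later entry lookups read from the arrays instead of re-running the whole chain of closures.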

Moritz Firsching (May 07 2023 at 20:15):

Thanks, @Kyle Miller that is certainly something to keep in mind. I was not really going for performance with this, but basically just playing around, because I didn't find Gaussian elimination in mathlib already.

Yakov Pechersky (May 07 2023 at 20:25):

iirc Sebastien added something related to elementary matrices a while back

Eric Wieser (May 07 2023 at 21:45):

docs#matrix.transvection is in the same file as what Yakov Pechersky is referring to


Last updated: Dec 20 2023 at 11:08 UTC