Zulip Chat Archive
Stream: mathlib4
Topic: termination_by in Gaussian elimination
Moritz Firsching (May 05 2023 at 12:55):
I implemented a naive Gaussian elimination and it works in the sense that I can now invert invertible matrices, but I have trouble showing termination of my recursion.
import Mathlib.Tactic
import Mathlib.Data.Matrix.Notation
import Mathlib.Algebra.Field.Basic

variable {α : Type _} [Field α] {m n : ℕ} [NeZero m] [NeZero n] [DecidableEq α]

def swap_rows (A : Matrix (Fin m) (Fin n) α) (i j : Fin m) :
    Matrix (Fin m) (Fin n) α :=
  fun r c ↦ if r = i then A j c else if r = j then A i c else A r c

def scale_row (A : Matrix (Fin m) (Fin n) α) (i : Fin m) (k : α) :
    Matrix (Fin m) (Fin n) α :=
  fun r c ↦ if r = i then k * A r c else A r c

def add_scaled_row (A : Matrix (Fin m) (Fin n) α) (i j : Fin m) (k : α) :
    Matrix (Fin m) (Fin n) α :=
  fun r c ↦ if r = i then A r c + k * A j c else A r c

def gaussian_elimination_step (A : Matrix (Fin m) (Fin n) α) (i : Fin m) (j : Fin n) :
    Matrix (Fin m) (Fin n) α :=
  if _ : A i j ≠ 0 then
    let k := (A i j)⁻¹
    let A' := scale_row A i k
    List.foldl (fun (A'' : Matrix (Fin m) (Fin n) α) (r : Fin m) ↦
      if r ≠ i
      then add_scaled_row A'' r i (-A'' r j)
      else A'') A' (List.finRange m)
  else A

def step : Matrix (Fin m) (Fin n) α → ℕ → ℕ → Matrix (Fin m) (Fin n) α
  | A', i, j =>
    if h : i < m ∧ j < n then
      let rows_to_search := List.drop i (List.finRange m)
      let pivot := List.find? (fun r => A' r ⟨j, (h.right)⟩ != 0) rows_to_search
      match pivot with
      | none => step A' i (j + 1)
      | some r =>
        let A'' := swap_rows A' ⟨i, h.left⟩ r
        let A''' := gaussian_elimination_step A'' ⟨i, h.left⟩ ⟨j, h.right⟩
        step A''' (i + 1) (j + 1)
    else A'
termination_by _ i j => (m - i) * (n - j)
decreasing_by sorry

def gaussian_elimination (A : Matrix (Fin m) (Fin n) α) : Matrix (Fin m) (Fin n) α :=
  step A 0 0

def augment_with_identity (A : Matrix (Fin n) (Fin n) α) : Matrix (Fin n) (Fin (n + n)) α :=
  fun i j ↦ if j.val < n then A i j else (1 : Matrix (Fin n) (Fin n) α) i (j - n)

def extract_inverse (A : Matrix (Fin n) (Fin (n + n)) α) : Matrix (Fin n) (Fin n) α :=
  fun i j ↦ by
    have h : n + j < n + n := by simp only [add_lt_add_iff_left, Fin.is_lt]
    exact A i ⟨n + j, h⟩

def matrix_inverse (A : Matrix (Fin n) (Fin n) α) : Matrix (Fin n) (Fin n) α :=
  let augmented := augment_with_identity A
  extract_inverse <| gaussian_elimination augmented

-- invertible example
def M := !![(2 : ℚ), -1/2; -3, 1]
def N := matrix_inverse M
#eval M*N == 1
#eval N*M == 1

-- non-invertible example
def M' := !![(2 : ℚ), -1/2; 6, -3/2]
def N' := matrix_inverse M'
#eval N'
#eval M'*N' == 1
How do I show termination? I'm not even sure that the part termination_by _ i j => (m - i) * (n - j) is the right start. In words, I'd argue that the recursion ends because i or j increases in each step, and as soon as they are at least m or n respectively, we are finished.
Scott Morrison (May 05 2023 at 22:49):
A brief answer: if you want to do termination_by some complicated expression, you will need to have the fact that that quantity has decreased in the context of each recursive call. That is, the proof that it is decreasing should be done inline with the code. This is contrary to the usual practice of proving things after defining them, unfortunately, and so for sufficiently complicated situations it requires rethinking the approach of just using termination_by.
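A minimal sketch of this advice, on a hypothetical toy function countUp that is not part of the thread: the inequality showing that the measure m - i shrinks is introduced with have right before the recursive call, so it is in the local context when termination is checked. The termination_by line follows the syntax used in this thread (Lean 4 as of 2023); later Lean versions changed that syntax, so it may need adjusting. In a case this simple Lean can probably find the proof on its own, and the have is shown only to illustrate the pattern.

```lean
import Mathlib.Tactic

-- Hypothetical toy example (not from the thread):
-- counts the number of steps from `i` up to `m`.
def countUp (m : ℕ) : ℕ → ℕ
  | i =>
    if h : i < m then
      -- Prove the decrease inline, before the recursive call,
      -- so the termination check can pick it up from the context.
      have : m - (i + 1) < m - i :=
        Nat.sub_lt_sub_left h (Nat.lt_succ_self i)
      countUp m (i + 1) + 1
    else 0
termination_by countUp i => m - i
```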
Moritz Firsching (May 06 2023 at 11:48):
Thanks! That helped a lot, I'm able to show termination now (not at all golfed yet...):
def step : Matrix (Fin m) (Fin n) α → ℕ → ℕ → Matrix (Fin m) (Fin n) α
  | A', i, j =>
    if h : i < m ∧ j < n then
      have hn : n - (j + 1) < n - j := by
        refine' Nat.sub_lt_sub_left _ _
        linarith
        simp only [lt_add_iff_pos_right]
      have hm : m - (i + 1) < m - i := by
        refine' Nat.sub_lt_sub_left _ _
        linarith
        simp only [lt_add_iff_pos_right]
      let rows_to_search := List.drop i (List.finRange m)
      let pivot := List.find? (fun r => A' r ⟨j, (h.right)⟩ != 0) rows_to_search
      match pivot with
      | none =>
        step A' i (j + 1)
      | some r =>
        let A'' := swap_rows A' ⟨i, h.left⟩ r
        let A''' := gaussian_elimination_step A'' ⟨i, h.left⟩ ⟨j, h.right⟩
        have : m - (i + 1) + (n - (j + 1)) < m - i + (n - j) := by linarith
        step A''' (i + 1) (j + 1)
    else A'
termination_by step _ i j => (m - i) + (n - j)
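The have statements in this definition come down to a few facts about natural-number subtraction, which can be stated on their own. The following is a sketch, assuming the core lemmas Nat.sub_lt_sub_left, Nat.add_lt_add_left and Nat.add_lt_add with their usual signatures; it is one possible way to shorten the proofs, not necessarily how they were eventually golfed. It also suggests why the sum (m - i) + (n - j) is more convenient than the product from the first attempt: the sum decreases as soon as one summand does, whereas a strict decrease of the product additionally needs the other factor to be positive.

```lean
import Mathlib.Tactic

-- Advancing an in-bounds index shrinks the remaining distance.
example (m i : ℕ) (h : i < m) : m - (i + 1) < m - i :=
  Nat.sub_lt_sub_left h (Nat.lt_succ_self i)

-- `none` branch: only `j` advances, so only the second summand shrinks.
example (m n i j : ℕ) (hn : n - (j + 1) < n - j) :
    m - i + (n - (j + 1)) < m - i + (n - j) :=
  Nat.add_lt_add_left hn (m - i)

-- `some r` branch: both indices advance, so both summands shrink.
example (m n i j : ℕ) (hm : m - (i + 1) < m - i) (hn : n - (j + 1) < n - j) :
    m - (i + 1) + (n - (j + 1)) < m - i + (n - j) :=
  Nat.add_lt_add hm hn
```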
Kyle Miller (May 06 2023 at 12:12):
If you're planning on running this algorithm in practice, watch out since it will have worse asymptotic performance than you might expect. This is because the matrices are functions, so, for example, swap_rows is creating a new closure that calculates entries from scratch every time. Using arrays of arrays (or writing a function that caches all the values of a matrix as an array of arrays then creates a new Matrix that reads from these arrays) would be a way to fix this.
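A rough sketch of the caching idea, with hypothetical helper names toArrays and ofArrays that are not Mathlib functions. It assumes Array.ofFn and the panicking index notation a[i]! from Lean core, and an Inhabited instance on the entry type for out-of-bounds lookups (which should not occur for arrays produced by toArrays). How much recomputation this actually avoids depends on how the compiler treats the surrounding code, so treat it as a starting point rather than a measured fix.

```lean
import Mathlib.Data.Matrix.Notation

-- Hypothetical helper: evaluate every entry of `A` once and
-- store the results row by row in an array of arrays.
def toArrays {α : Type _} {m n : ℕ}
    (A : Matrix (Fin m) (Fin n) α) : Array (Array α) :=
  Array.ofFn fun i : Fin m => Array.ofFn fun j : Fin n => A i j

-- Hypothetical helper: a matrix whose entries are plain array lookups.
-- Out-of-range lookups fall back to the `Inhabited` default.
def ofArrays {α : Type _} [Inhabited α] {m n : ℕ}
    (rows : Array (Array α)) : Matrix (Fin m) (Fin n) α :=
  fun i j => (rows[i.val]!)[j.val]!
```

One possible use would be to replace an intermediate matrix B by ofArrays (toArrays B) after each elimination step, so that later steps read stored values instead of re-running the chain of closures.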
Moritz Firsching (May 07 2023 at 20:15):
Thanks, @Kyle Miller, that is certainly something to keep in mind. I was not really going for performance with this, but basically just playing around, because I didn't find Gaussian elimination in mathlib already.
Yakov Pechersky (May 07 2023 at 20:25):
iirc Sebastien added something related to elementary matrices a while back
Eric Wieser (May 07 2023 at 21:45):
docs#matrix.transvection is in the same file as what Yakov Pechersky is referring to
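For orientation, a small sketch of how that definition relates to add_scaled_row from the first message. It assumes the Lean 4 Mathlib name Matrix.transvection and the module Mathlib.LinearAlgebra.Matrix.Transvection (the link above uses the Lean 3 name); the matrix T is a hypothetical example. Matrix.transvection i j c is the identity matrix plus c in position (i, j), so multiplying by it on the left adds c times row j to row i.

```lean
import Mathlib.LinearAlgebra.Matrix.Transvection
import Mathlib.Data.Matrix.Notation

-- Hypothetical example: the 2×2 transvection with entry 3 in position (0, 1).
def T : Matrix (Fin 2) (Fin 2) ℚ := Matrix.transvection 0 1 3

-- For a 2×2 matrix `A`, the product `T * A` adds 3 times row 1 of `A`
-- to row 0, the same row operation as `add_scaled_row A 0 1 3`.
```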
Last updated: Dec 20 2023 at 11:08 UTC