Zulip Chat Archive

Stream: maths

Topic: Trouble with tensor product and submatrix proof


Anirudh Suresh (Jun 27 2025 at 04:19):

import Mathlib
import Lean
import Lean.Elab.Tactic
import Lean.Elab.Term
import Lean.Meta
open Lean
open Lean.Elab.Tactic
open Lean.Elab.Term
open Lean.Meta
open scoped Matrix
open Matrix

def tensor_product_2 {m n o p : ℕ}
  (A : Matrix (Fin m) (Fin n) ℂ)
  (B : Matrix (Fin o) (Fin p) ℂ) :
  Matrix (Fin (m * o)) (Fin (n * p)) ℂ :=
  (Matrix.kroneckerMap (fun (a b : ℂ) => a * b) A B).reindex finProdFinEquiv finProdFinEquiv

-- Provide a convenient infix notation for the tensor product
infixl:80 " ⊗ " => tensor_product_2


@[reducible]
def Square (n : ℕ) := Matrix (Fin n) (Fin n) ℂ

def Zero_Matrix (m n : ℕ) : Matrix (Fin m) (Fin n) ℂ := fun _ _ => 0

def identity_matrix (n : ℕ) : Square n :=
  fun i j => if i = j then 1 else 0

@[simp] lemma tensor_assoc {m n o p q r : ℕ}
  (A : Matrix (Fin m) (Fin n) ℂ)
  (B : Matrix (Fin o) (Fin p) ℂ)
  (C : Matrix (Fin q) (Fin r) ℂ) :
  tensor_product_2 (tensor_product_2 A B) C =
      (tensor_product_2 A (tensor_product_2 B C)).submatrix (Fin.cast (by simp[mul_assoc])) (Fin.cast (by simp[mul_assoc])) := by {
        sorry
      }

def pad (n t dim : ℕ) (U : Square (2 ^ n)) : Square (2 ^ dim) :=
  if H : t + n ≤ dim then
    let foo : Square (2 ^ t * 2 ^ n * 2 ^ (dim - n - t)) :=
      (identity_matrix (2 ^ t)) ⊗ U ⊗ identity_matrix (2 ^ (dim - n - t))
    -- now cast both the row-index and col-index from `2^dim` down to
    -- `2^t * 2^n * 2^(dim-n-t)` via `Fin.cast`, using `simp` to prove
    -- the two naturals are equal by repeated `pow_add` and `tsub` lemmas.
    foo.submatrix
      (Fin.cast (by {
        rw [← Nat.add_sub_of_le H]
        rw [Nat.pow_add]
        rw [Nat.pow_add]
        rw [← Nat.sub_sub]
        simp[Nat.sub_sub,Nat.add_comm]
      }))
      (Fin.cast (by {
        rw [← Nat.add_sub_of_le H]
        rw [Nat.pow_add]
        rw [Nat.pow_add]
        rw [← Nat.sub_sub]
        simp[Nat.sub_sub,Nat.add_comm]
      }))
  else
    Zero_Matrix (2 ^ dim) (2 ^ dim)


lemma pad_succ {n t dim : ℕ} {U : Square (2 ^ n)} (h : t + n ≤ dim) :
  pad n t (dim + 1) U = pad n t dim U ⊗ identity_matrix 2 := by {
    dsimp [pad]
    split_ifs with h'
    {
      sorry
    }
    {
      exact False.elim (Nat.not_le_of_gt (by linarith) h)
    }
  }

In the "pos" case of the pad_succ proof after split_ifs, I am having trouble moving the identity_matrix 2 on the RHS inside so that it can be tensored with identity_matrix (2 ^ (dim - n - t)), because the submatrix is in the way. What is the best way to get past this?

Kenny Lau (Jun 27 2025 at 10:05):

@Anirudh Suresh Prove separately the lemma that taking the submatrix of an identity matrix along the same injective map on both rows and columns gives the corresponding identity matrix.

Anirudh Suresh (Jun 27 2025 at 12:50):

import Mathlib
import Lean
import Lean.Elab.Tactic
import Lean.Elab.Term
import Lean.Meta
open Lean
open Lean.Elab.Tactic
open Lean.Elab.Term
open Lean.Meta
open scoped Matrix
open Matrix

def tensor_product_2 {m n o p : ℕ}
  (A : Matrix (Fin m) (Fin n) ℂ)
  (B : Matrix (Fin o) (Fin p) ℂ) :
  Matrix (Fin (m * o)) (Fin (n * p)) ℂ :=
  (Matrix.kroneckerMap (fun (a b : ℂ) => a * b) A B).reindex finProdFinEquiv finProdFinEquiv

-- Provide a convenient infix notation for the tensor product
infixl:80 " ⊗ " => tensor_product_2


@[reducible]
def Square (n : ℕ) := Matrix (Fin n) (Fin n) ℂ

def Zero_Matrix (m n : ℕ) : Matrix (Fin m) (Fin n) ℂ := fun _ _ => 0

def identity_matrix (n : ℕ) : Square n :=
  fun i j => if i = j then 1 else 0

@[simp] lemma tensor_assoc {m n o p q r : ℕ}
  (A : Matrix (Fin m) (Fin n) ℂ)
  (B : Matrix (Fin o) (Fin p) ℂ)
  (C : Matrix (Fin q) (Fin r) ℂ) :
  tensor_product_2 (tensor_product_2 A B) C =
      (tensor_product_2 A (tensor_product_2 B C)).submatrix (Fin.cast (by simp[mul_assoc])) (Fin.cast (by simp[mul_assoc])) := by {
        sorry
      }
@[simp] lemma tensor_assoc4 {m n o p q r s t : ℕ}
  (A : Matrix (Fin m) (Fin n) ℂ)
  (B : Matrix (Fin o) (Fin p) ℂ)
  (C : Matrix (Fin q) (Fin r) ℂ)
  (D : Matrix (Fin s) (Fin t) ℂ) :
  tensor_product_2 (tensor_product_2 (tensor_product_2 A B) C) D =
      (tensor_product_2 A (tensor_product_2 B (tensor_product_2 C D))).submatrix (Fin.cast (by simp[mul_assoc])) (Fin.cast (by simp[mul_assoc])) := by {
        sorry
      }

def pad (n t dim : ℕ) (U : Square (2 ^ n)) : Square (2 ^ dim) :=
  if H : t + n ≤ dim then
    let foo : Square (2 ^ t * 2 ^ n * 2 ^ (dim - n - t)) :=
      (identity_matrix (2 ^ t)) ⊗ U ⊗ identity_matrix (2 ^ (dim - n - t))
    -- now cast both the row-index and col-index from `2^dim` down to
    -- `2^t * 2^n * 2^(dim-n-t)` via `Fin.cast`, using `simp` to prove
    -- the two naturals are equal by repeated `pow_add` and `tsub` lemmas.
    foo.submatrix
      (Fin.cast (by {
        rw [← Nat.add_sub_of_le H]
        rw [Nat.pow_add]
        rw [Nat.pow_add]
        rw [← Nat.sub_sub]
        simp[Nat.sub_sub,Nat.add_comm]
      }))
      (Fin.cast (by {
        rw [← Nat.add_sub_of_le H]
        rw [Nat.pow_add]
        rw [Nat.pow_add]
        rw [← Nat.sub_sub]
        simp[Nat.sub_sub,Nat.add_comm]
      }))
  else
    Zero_Matrix (2 ^ dim) (2 ^ dim)

@[simp]
lemma submatrix_identity {m n : ℕ} {f : Fin n → Fin m} (hf : Function.Injective f) :
  (identity_matrix m).submatrix f f = identity_matrix n := by
  -- we prove it entrywise
  ext i j
  -- `submatrix` and `I` both unfold to `fun i j => if … then 1 else 0`
  dsimp [submatrix, identity_matrix]
  -- now the equality `f i = f j ↔ i = j` is exactly `hf.eq_iff`
  simp [hf.eq_iff]

lemma pad_succ {n t dim : ℕ} {U : Square (2 ^ n)} (h : t + n ≤ dim) :
  pad n t (dim + 1) U = pad n t dim U ⊗ identity_matrix 2 := by {
    dsimp [pad]
    split_ifs with h'
    {
      sorry
    }
    {
      exact False.elim (Nat.not_le_of_gt (by linarith) h)
    }
  }

But even with this lemma, there is still the issue of splitting the tensor products apart and then dealing with the submatrix.
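A minimal sketch of one tool for the submatrix part, assuming ℂ scalars as elsewhere in the thread: Mathlib's Matrix.submatrix_submatrix merges nested submatrix calls, and a composition of Fin.cast maps agrees definitionally with a single Fin.cast, so stacked casts can be collapsed into one before splitting the tensor products. This alone does not prove pad_succ.

```lean
import Mathlib

open Matrix

-- Two nested `Fin.cast` reindexings collapse into one: first merge the
-- `submatrix` calls, then the composed casts agree definitionally.
example {a b c : ℕ} (h₁ : a = b) (h₂ : b = c) (M : Matrix (Fin c) (Fin c) ℂ) :
    (M.submatrix (Fin.cast h₂) (Fin.cast h₂)).submatrix (Fin.cast h₁) (Fin.cast h₁) =
      M.submatrix (Fin.cast (h₁.trans h₂)) (Fin.cast (h₁.trans h₂)) := by
  -- goal becomes `M.submatrix (Fin.cast h₂ ∘ Fin.cast h₁) (...) = ...`
  rw [Matrix.submatrix_submatrix]
  -- both sides send `i` to `⟨i.val, _⟩`, so they are definitionally equal
  rfl
```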

Eric Wieser (Jun 27 2025 at 13:02):

Things will be easier if you use 1 instead of your identity

Eric Wieser (Jun 27 2025 at 13:03):

(since docs#Matrix.submatrix_one surely exists)
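A sketch of what switching to 1 buys, assuming ℂ as the scalar type: identity_matrix is restated from the thread, the lemma name identity_matrix_eq_one is hypothetical, and the final example is intended as a direct instance of Matrix.submatrix_one, taking the map and its injectivity proof as explicit arguments.

```lean
import Mathlib

open Matrix

-- Restated from the thread (scalar type assumed to be ℂ).
def identity_matrix (n : ℕ) : Matrix (Fin n) (Fin n) ℂ :=
  fun i j => if i = j then 1 else 0

-- Hypothetical bridging lemma: the hand-rolled identity is Mathlib's `1`.
lemma identity_matrix_eq_one (n : ℕ) : identity_matrix n = 1 := by
  refine Matrix.ext fun i j => ?_
  -- both sides unfold to `if i = j then 1 else 0`
  simp [identity_matrix, Matrix.one_apply]

-- Once `1` is used, the submatrix fact needs no custom lemma.
example {m n : ℕ} (f : Fin n → Fin m) (hf : Function.Injective f) :
    (1 : Matrix (Fin m) (Fin m) ℂ).submatrix f f = 1 :=
  Matrix.submatrix_one f hf
```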

Kenny Lau (Jun 27 2025 at 13:05):

Things will also be easier if you use Fintype rather than Fin n so that you won't have to carry the isomorphism Fin m × Fin n ≃ Fin (m * n) around...
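To illustrate, a sketch using Mathlib's Kronecker notation ⊗ₖ (from open scoped Kronecker) on arbitrary index types, with ℂ scalars assumed. The two facts the thread keeps needing, identity ⊗ identity = identity and associativity, are stated over product types; associativity then holds up to Equiv.prodAssoc instead of a Fin.cast, via the Mathlib lemmas Matrix.one_kronecker_one and Matrix.kronecker_assoc.

```lean
import Mathlib

open Matrix
open scoped Kronecker

variable {l m n p q r : Type*}

-- Identity ⊗ identity is the identity on the product index type.
example [DecidableEq l] [DecidableEq n] :
    (1 : Matrix l l ℂ) ⊗ₖ (1 : Matrix n n ℂ) = 1 := by
  rw [Matrix.one_kronecker_one]

-- Associativity up to the canonical equivalence `(l × n) × q ≃ l × (n × q)`;
-- no `Fin.cast` or arithmetic on sizes is involved.
example (A : Matrix l m ℂ) (B : Matrix n p ℂ) (C : Matrix q r ℂ) :
    reindex (Equiv.prodAssoc l n q) (Equiv.prodAssoc m p r) (A ⊗ₖ B ⊗ₖ C) =
      A ⊗ₖ (B ⊗ₖ C) := by
  rw [Matrix.kronecker_assoc]
```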

Kenny Lau (Jun 27 2025 at 13:06):

i.e. my proof would be to first transport the statement to an equality over the correct fintypes, prove it there, and then transport it back to Fin
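A sketch of that two-step strategy in the simplest case, assuming ℂ scalars: tensor_product_2 is restated from the thread (with Matrix.reindex written in prefix form), and the lemma name tensor_product_2_one_one is hypothetical. The fact is proved over Fin m × Fin n, where Mathlib already has it, and then carried across finProdFinEquiv.

```lean
import Mathlib

open Matrix
open scoped Kronecker

-- Restated from the thread (scalar type assumed to be ℂ).
def tensor_product_2 {m n o p : ℕ}
    (A : Matrix (Fin m) (Fin n) ℂ) (B : Matrix (Fin o) (Fin p) ℂ) :
    Matrix (Fin (m * o)) (Fin (n * p)) ℂ :=
  Matrix.reindex finProdFinEquiv finProdFinEquiv
    (Matrix.kroneckerMap (fun (a b : ℂ) => a * b) A B)

-- Hypothetical lemma name.
lemma tensor_product_2_one_one (m n : ℕ) :
    tensor_product_2 (1 : Matrix (Fin m) (Fin m) ℂ) (1 : Matrix (Fin n) (Fin n) ℂ) = 1 := by
  -- Step 1: the statement over the product type `Fin m × Fin n`.
  have h : (1 : Matrix (Fin m) (Fin m) ℂ) ⊗ₖ (1 : Matrix (Fin n) (Fin n) ℂ) = 1 :=
    Matrix.one_kronecker_one
  -- Step 2: unfold, rewrite with `h`, and let `simp` show that reindexing `1`
  -- along an equivalence gives `1` again.
  unfold tensor_product_2
  rw [h]
  simp
```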

Eric Wieser (Jun 27 2025 at 13:10):

It's perhaps worth merging this with the previous threads by the OP about the same thing

Eric Wieser (Jun 27 2025 at 13:11):

Though I think the 2 ^ n part makes this require a bit more thought; perhaps this is really about matrices indexed by Pi types, like I -> Fin 2?
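A minimal sketch of the Pi-type indexing, assuming an n-qubit register is indexed by bit strings Fin n → Fin 2; QubitMatrix and qubitEquiv are hypothetical names. The count 2 ^ n comes out of Fintype.card, and Fintype.equivFinOfCardEq recovers a Fin (2 ^ n) index only where one is actually required.

```lean
import Mathlib

-- Bit strings of length `n` index exactly `2 ^ n` basis states.
example (n : ℕ) : Fintype.card (Fin n → Fin 2) = 2 ^ n := by
  simp

-- Hypothetical name: operators on `n` qubits indexed by bit strings.
abbrev QubitMatrix (n : ℕ) := Matrix (Fin n → Fin 2) (Fin n → Fin 2) ℂ

-- Hypothetical name: a bridge back to `Fin (2 ^ n)` built from the cardinality
-- computation; it is noncomputable because the equivalence is only shown to exist.
noncomputable def qubitEquiv (n : ℕ) : (Fin n → Fin 2) ≃ Fin (2 ^ n) :=
  Fintype.equivFinOfCardEq (by simp)
```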

Kenny Lau (Jun 27 2025 at 13:12):

I think if you do enough abstract linear algebra (as opposed to the CS approach of working with explicit matrices) you'll realise that the order of the basis doesn't really matter

Kenny Lau (Jun 27 2025 at 13:12):

I feel like all of the definitions here are trying to bridge the maths view with the CS view and that might not be a good thing

Eric Wieser (Jun 27 2025 at 13:17):

I think the missing piece(s) here are n-ary kronecker products, defined either on commutative coefficient rings or on ordered indices
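A sketch of the commutative-coefficient version, with hypothetical names piKronecker and piKronecker_one and ℂ scalars assumed. This is not taken from Mathlib; it only illustrates the shape such a definition could take. For a family of matrices indexed by i : Fin k, each entry is the product of the corresponding entries of the factors, so the result is indexed by dependent functions and no ordering of the factors is needed.

```lean
import Mathlib

-- Hypothetical n-ary Kronecker product over a commutative coefficient ring (ℂ).
def piKronecker {k : ℕ} {ι κ : Fin k → Type*}
    (A : (i : Fin k) → Matrix (ι i) (κ i) ℂ) :
    Matrix ((i : Fin k) → ι i) ((i : Fin k) → κ i) ℂ :=
  Matrix.of fun x y => ∏ i, A i (x i) (y i)

-- Sanity check: the n-ary product of identity matrices is the identity.
lemma piKronecker_one {k : ℕ} {ι : Fin k → Type*} [∀ i, DecidableEq (ι i)] :
    piKronecker (fun i => (1 : Matrix (ι i) (ι i) ℂ)) = 1 := by
  refine Matrix.ext fun x y => ?_
  -- goal: `∏ i, (if x i = y i then 1 else 0) = if x = y then 1 else 0`
  simp only [piKronecker, Matrix.of_apply, Matrix.one_apply]
  by_cases h : x = y
  · subst h
    simp
  · rw [if_neg h]
    -- some coordinate differs, so one factor of the product is zero
    obtain ⟨i, hi⟩ : ∃ i, x i ≠ y i := by
      by_contra hc
      push_neg at hc
      exact h (funext hc)
    exact Finset.prod_eq_zero (Finset.mem_univ i) (if_neg hi)
```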

Eric Wieser (Jun 27 2025 at 13:18):

A bit like how we have docs#MultilinearMap.mkPiAlgebraFin

Eric Wieser (Jun 27 2025 at 13:19):

I think we'll need that result to show that matrices commute with an n-ary tensor product


Last updated: Dec 20 2025 at 21:32 UTC