Zulip Chat Archive

Stream: lean4

Topic: Efficiency issue: small computation takes 10min and max mem


Xiaoning Bian (Jan 15 2024 at 19:17):

The following code, translated from my Agda code, runs really slowly on an 8-case distinction and maxes out my 16 GB of memory. My goal was to speed up the verification of a 2^12-case distinction (Agda checks this in about 40 min), but now that seems hard to achieve. Are there any methods I can use to speed up my Lean 4 code?

-- import Nat
universe u
namespace VerCCZ
--abbrev ℕ := Nat
notation "ℕ" => Nat

-- Basic gates. They act on the first few qubits of an n-qubit system.
inductive Gate where
  -- one qubit gate.
  | Omega : Gate
  | I : Gate
  | Z : Gate
  | S : Gate
  | T : Gate
  | H : Gate
  | X : Gate
  | Meas : Bool -> Gate
  -- two qubit gate.
  | CX : Gate
  | XC : Gate -- Controlled X gate with target on 0th qubit and control on 1st qubit.
  | CZ : Gate
  | CS : Gate
  | Swap : ℕ -> Gate -- Swap gate on j-th and (j+1)-th qubits.
  -- three qubit gate.
  | CCZ : Gate
  | CCX : Gate -- Toffoli gate.

-- Extended gates. They act on any qubits of an n-qubit system.
inductive GateExt where
  -- one qubit gate.
  | Omega : ℕ -> GateExt
  | I : ℕ -> GateExt
  | Z : ℕ -> GateExt
  | S : ℕ -> GateExt
  | T : ℕ -> GateExt
  | H : ℕ -> GateExt
  | X : ℕ -> GateExt
  | Meas : Bool -> ℕ -> GateExt
  -- two qubit gate.
  | CX : ℕ -> ℕ -> GateExt
  | XC : ℕ -> ℕ -> GateExt
  | CZ : ℕ -> ℕ -> GateExt
  | CS : ℕ -> ℕ -> GateExt
  | Swap : ℕ -> ℕ -> GateExt
  -- three qubit gate.
  | CCZ : ℕ -> ℕ -> ℕ -> GateExt
  | CCX : ℕ -> ℕ -> ℕ -> GateExt -- Toffoli gate.

inductive Ket : ℕ -> Type where
  | Ket1 : Bool -> (Ket 1)
  | KetS : {n : ℕ} -> Ket 1 -> Ket n -> Ket (.succ n)
--  deriving Repr

open Ket

notation "K0" => Ket1 false
notation "K1" => Ket1 true

notation "K00" => KetS K0 K0
notation "K01" => KetS K0 K1
notation "K10" => KetS K1 K0
notation "K11" => KetS K1 K1

notation "K000" => KetS K0 K00
notation "K001" => KetS K0 K01
notation "K010" => KetS K0 K10
notation "K011" => KetS K0 K11
notation "K100" => KetS K1 K00
notation "K101" => KetS K1 K01
notation "K110" => KetS K1 K10
notation "K111" => KetS K1 K11


notation "K000000" => KetS K0 (KetS K0 (KetS K0 (KetS K0 (KetS K0 (Ket1 false)))))
notation "K000011" => KetS K0 (KetS K0 (KetS K0 (KetS K0 (KetS K1 (Ket1 true)))))
notation "K001100" => KetS K0 (KetS K0 (KetS K1 (KetS K1 (KetS K0 (Ket1 false)))))
notation "K001111" => KetS K0 (KetS K0 (KetS K1 (KetS K1 (KetS K1 (Ket1 true)))))
notation "K110000" => KetS K1 (KetS K1 (KetS K0 (KetS K0 (KetS K0 (Ket1 false)))))
notation "K110011" => KetS K1 (KetS K1 (KetS K0 (KetS K0 (KetS K1 (Ket1 true)))))
notation "K111100" => KetS K1 (KetS K1 (KetS K1 (KetS K1 (KetS K0 (Ket1 false)))))
notation "K111111" => KetS K1 (KetS K1 (KetS K1 (KetS K1 (KetS K1 (Ket1 true)))))

notation "K000000000" => KetS K0 (KetS K0 (KetS K0 K000000))
notation "K001000000" => KetS K0 (KetS K0 (KetS K1 K000000))
notation "K010000000" => KetS K0 (KetS K1 (KetS K0 K000000))
notation "K011000000" => KetS K0 (KetS K1 (KetS K1 K000000))
notation "K100000000" => KetS K1 (KetS K0 (KetS K0 K000000))
notation "K101000000" => KetS K1 (KetS K0 (KetS K1 K000000))
notation "K110000000" => KetS K1 (KetS K1 (KetS K0 K000000))
notation "K111000000" => KetS K1 (KetS K1 (KetS K1 K000000))

-- We only use omega ^ k, where k = 0,1,...,7.
inductive Scalar where
  | omega : ℕ → Scalar
open Scalar

deriving instance BEq, Hashable, Repr for Scalar
deriving instance BEq, Hashable, Repr for Gate
deriving instance BEq, Hashable, Repr for GateExt
deriving instance BEq, Hashable, Repr for Ket


-- Scalar multiplication: omega a represents ω^a with ω^8 = 1, so
-- ω^a * ω^b = ω^((a + b) mod 8).
def smul : Scalar -> Scalar -> Scalar
  | (omega a),  (omega b) => omega ((a + b) % 8)

instance : Mul Scalar where
  mul := smul


notation "P1" => omega 0 -- 1
notation "Pw" => omega 1 -- omega
notation "M1" => omega 4 -- -1
notation "Mw" => omega 5 -- -omega
notation "Pi" => omega 2 -- i
notation "Mi" => omega 6 -- -i

open Ket
open List
open Gate

-- act works for n >= 4. k' in the 1-qubit gate case is of
-- (n-1)-qubit, and 2-qubit gate case, (n-2)-qubit, and 3-qubit gate
-- case, (n-3)-qubit. k' cannot be 0-qubit.
def act {n : ℕ} (g : Gate) (k : Ket n) : List (Scalar × Ket n) :=
  match g, k with
    | I , k@(KetS K0 k') => (P1 , k) :: []
    | I, k@(KetS K1 k') => (P1 , k) :: []
    | Omega, k@(KetS K0 k') => (Pw , k) :: []
    | Omega, k@(KetS K1 k') => (Pw , k) :: []
    | Z, k@(KetS K0 k') => (P1 , k) :: []
    | Z, k@(KetS K1 k') => (M1 , k) :: []
    | S, k@(KetS K0 k') => (P1 , k) :: []
    | S, k@(KetS K1 k') => (Pi , k) :: []
    | T, k@(KetS K0 k') => (P1 , k) :: []
    | T, k@(KetS K1 k') => (Pw , k) :: []
    | H, k@(KetS K0 k') => (P1 , KetS K0 k') :: (P1 , KetS K1 k') :: []
    | H, k@(KetS K1 k') => (P1 , KetS K0 k') :: (M1 , KetS K1 k') :: []
    | X, k@(KetS K0 k') => (P1 , KetS K1 k') :: []
    | X, k@(KetS K1 k') => (P1 , KetS K0 k') :: []
    | (Meas false), k@(KetS K0 k') => (P1 , k) :: []
    | (Meas true), k@(KetS K1 k') => (P1 , k) :: []
    | CZ, k@(KetS K0 (KetS K0 k')) => (P1 , k) :: []
    | CZ, k@(KetS K0 (KetS K1 k')) => (P1 , k) :: []
    | CZ, k@(KetS K1 (KetS K0 k')) => (P1 , k) :: []
    | CZ, k@(KetS K1 (KetS K1 k')) => (M1 , k) :: []
    | CS, k@(KetS K0 (KetS K0 k')) => (P1 , k) :: []
    | CS, k@(KetS K0 (KetS K1 k')) => (P1 , k) :: []
    | CS, k@(KetS K1 (KetS K0 k')) => (P1 , k) :: []
    | CS, k@(KetS K1 (KetS K1 k')) => (Pi , k) :: []
    | CX, k@(KetS K0 (KetS K0 k')) => (P1 , k) :: []
    | CX, k@(KetS K0 (KetS K1 k')) => (P1 , k) :: []
    | CX, k@(KetS K1 (KetS K0 k')) => (P1 , (KetS K1 (KetS K1 k'))) :: []
    | CX, k@(KetS K1 (KetS K1 k')) => (P1 , (KetS K1 (KetS K0 k'))) :: []
    | XC, k@(KetS K0 (KetS K0 k')) => (P1 , k) :: []
    | XC, k@(KetS K0 (KetS K1 k')) => (P1 , (KetS K1 (KetS K1 k'))) :: []
    | XC, k@(KetS K1 (KetS K0 k')) => (P1 , k) :: []
    | XC, k@(KetS K1 (KetS K1 k')) => (P1 , (KetS K0 (KetS K1 k'))) :: []
    | (Swap 0), K00 => (P1 , K00) :: []
    | (Swap 0), K01 => (P1 , K10) :: []
    | (Swap 0), K10 => (P1 , K01) :: []
    | (Swap 0), K11 => (P1 , K11) :: []
    | (Swap 0), k@(KetS K0 (KetS K0 k')) => (P1 , k) :: []
    | (Swap 0), k@(KetS K0 (KetS K1 k')) => (P1 , (KetS K1 (KetS K0 k'))) :: []
    | (Swap 0), k@(KetS K1 (KetS K0 k')) => (P1 , (KetS K0 (KetS K1 k'))) :: []
    | (Swap 0), k@(KetS K1 (KetS K1 k')) => (P1 , k) :: []
    | (Swap (.succ j)), k@(KetS k1 kj) => map (fun p => let (p1 , p2) := p
      (p1 , KetS k1 p2)) (act (Swap j) kj)
    | (CCZ), k@(KetS K0 (KetS K0 (KetS K0 k'))) => (P1 , k) :: []
    | (CCZ), k@(KetS K0 (KetS K0 (KetS K1 k'))) => (P1 , k) :: []
    | (CCZ), k@(KetS K0 (KetS K1 (KetS K0 k'))) => (P1 , k) :: []
    | (CCZ), k@(KetS K0 (KetS K1 (KetS K1 k'))) => (P1 , k) :: []
    | (CCZ), k@(KetS K1 (KetS K0 (KetS K0 k'))) => (P1 , k) :: []
    | (CCZ), k@(KetS K1 (KetS K0 (KetS K1 k'))) => (P1 , k) :: []
    | (CCZ), k@(KetS K1 (KetS K1 (KetS K0 k'))) => (P1 , k) :: []
    | (CCZ), k@(KetS K1 (KetS K1 (KetS K1 k'))) => (M1 , k) :: []
    | (CCX), k@(KetS K0 (KetS K0 (KetS K0 k'))) => (P1 , k) :: []
    | (CCX), k@(KetS K0 (KetS K0 (KetS K1 k'))) => (P1 , k) :: []
    | (CCX), k@(KetS K0 (KetS K1 (KetS K0 k'))) => (P1 , k) :: []
    | (CCX), k@(KetS K0 (KetS K1 (KetS K1 k'))) => (P1 , k) :: []
    | (CCX), k@(KetS K1 (KetS K0 (KetS K0 k'))) => (P1 , k) :: []
    | (CCX), k@(KetS K1 (KetS K0 (KetS K1 k'))) => (P1 , k) :: []
    | (CCX), k@(KetS K1 (KetS K1 (KetS K0 k'))) => (P1 , KetS K1 (KetS K1 (KetS K1 k'))) :: []
    | (CCX), k@(KetS K1 (KetS K1 (KetS K1 k'))) => (P1 , KetS K1 (KetS K1 (KetS K0 k'))) :: []
    | a, b => []

def scale {n : ℕ} (s1 : Scalar) (s2k : Scalar × Ket n) : Scalar × Ket n :=
  match s1, s2k with
    | s1, (s2 , ket) => (s1 * s2 , ket)


def actsigned {n : ℕ} (g : Gate) (sk : Scalar × Ket n) : List (Scalar × Ket n) :=
  match g, sk with
    | g, (s, ket) => map (scale s) (act g ket)

def actlist {n : ℕ} (g : Gate) (l : List (Scalar × Ket n)) : List (Scalar × Ket n) :=
  match g, l with
    | g, ls => join (map (fun x => actsigned g x) ls)

def actcir {n : ℕ} (l : List Gate) (ls : List (Scalar × Ket n)) : List (Scalar × Ket n) :=
  match l, ls with
    | [], lsk => lsk
    | (h :: t), lsk => actcir t (actlist h lsk)
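The `Scalar` type above encodes ω^k (with ω^8 = 1) by its exponent k, so the product of two scalars should add exponents modulo 8. A minimal standalone sketch, with a hypothetical `Phase` structure that is not part of the original code: storing the exponent as `Fin 8` makes the wraparound automatic and lets Lean derive `DecidableEq`, which `decide` needs.

```lean
-- Hypothetical alternative to `Scalar`: ω^k stored as k : Fin 8.
structure Phase where
  k : Fin 8
  deriving DecidableEq, Repr

-- ω^a * ω^b = ω^(a + b); addition on `Fin 8` already reduces mod 8.
instance : Mul Phase where
  mul a b := ⟨a.k + b.k⟩

-- Named phases mirroring the P1 / Pw / Pi / M1 notations in the post.
def Phase.one   : Phase := ⟨0⟩  -- 1
def Phase.omega : Phase := ⟨1⟩  -- ω
def Phase.i     : Phase := ⟨2⟩  -- i
def Phase.neg   : Phase := ⟨4⟩  -- -1

-- Each equation is checked by evaluation.
example : Phase.neg * Phase.neg = Phase.one := by decide
example : Phase.i * Phase.i = Phase.neg := by decide
example : Phase.omega * Phase.one = Phase.omega := by decide
```

With this representation, multiplying by 1 leaves a phase unchanged, and there is no separate `% 8` step to forget.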

Xiaoning Bian (Jan 15 2024 at 19:17):

open Nat
-- generate a list of Swap gate, which equals Swap j l.
def gswap : ℕ -> ℕ -> List Gate
  | 0, 0 => []
  | 0, 1 => Swap 0 :: []
  | 0, (.succ j) => Swap j :: gswap 0 j ++ Swap j :: []
  | 1, 0 => Swap 0 :: []
  | (.succ j), 0 => Swap j :: gswap j 0 ++ Swap j :: []
  | (.succ j), (.succ l) =>
    if Nat.blt (.succ j) l then Swap l :: gswap (.succ j) l ++ Swap l :: [] else
    if Nat.beq l (.succ j) then Swap (.succ j) :: [] else
    if Nat.beq j l then [] else Swap j :: gswap j l ++ Swap j :: []

-- Desugar GateExt to Gate.
def desugar : GateExt -> List Gate
  | (GateExt.Omega j) => gswap 0 j ++ Omega :: gswap 0 j
  | (GateExt.I j) => gswap 0 j ++ I :: gswap 0 j
  | (GateExt.Z j) => gswap 0 j ++ Z :: gswap 0 j
  | (GateExt.S j) => gswap 0 j ++ S :: gswap 0 j
  | (GateExt.T j) => gswap 0 j ++ T :: gswap 0 j
  | (GateExt.H j) => gswap 0 j ++ H :: gswap 0 j
  | (GateExt.X j) => gswap 0 j ++ X :: gswap 0 j
  | (GateExt.Meas b j) => gswap 0 j ++ Meas b :: gswap 0 j
  | (GateExt.CX j l) => gswap 0 j ++ gswap 1 l ++ CX :: gswap 1 l ++ gswap 0 j
  | (GateExt.XC j l) => gswap 0 j ++ gswap 1 l ++ XC :: gswap 1 l ++ gswap 0 j
  | (GateExt.CZ j l) => gswap 0 j ++ gswap 1 l ++ CZ :: gswap 1 l ++ gswap 0 j
  | (GateExt.CS j l) => gswap 0 j ++ gswap 1 l ++ CS :: gswap 1 l ++ gswap 0 j
  | (GateExt.Swap j l) => gswap j l
  | (GateExt.CCZ j l m) => gswap 0 j ++ gswap 1 l ++ gswap 2 m ++ CCZ :: gswap 2 m ++ gswap 1 l ++ gswap 0 j
  | (GateExt.CCX j l m) => gswap 0 j ++ gswap 1 l ++ gswap 2 m ++ CCX :: gswap 2 m ++ gswap 1 l ++ gswap 0 j

def actext {n : ℕ} (e : List GateExt) (k : List (Scalar × Ket n)) : List (Scalar × Ket n) :=
  actcir (join (map desugar e)) k

def enc (a0 b0 a1 b1 a2 b2 : Bool) : List GateExt :=
    (if b0 then GateExt.Z 0 :: [] else []) ++
    (if a0 then GateExt.X 0 :: [] else []) ++
    (if b1 then GateExt.Z 1 :: [] else []) ++
    (if a1 then GateExt.X 1 :: [] else []) ++
    (if b2 then GateExt.Z 2 :: [] else []) ++
    (if a2 then GateExt.X 2 :: [] else [])

def copy (c0 c1 c2 : Bool) : List GateExt :=
  (GateExt.H 3 :: GateExt.CX 3 4 :: GateExt.H 5 :: GateExt.CX 5 6 :: GateExt.H 7 :: GateExt.CX 7 8 :: GateExt.XC 0 3 :: GateExt.XC 1 5 :: GateExt.XC 2 7 :: []) ++
  (GateExt.Meas c0 0 :: GateExt.Meas c1 1 :: GateExt.Meas c2 2 :: []) ++
  (GateExt.CX 0 3 :: GateExt.CX 0 4  :: GateExt.CX 1 5 :: GateExt.CX 1 6 :: GateExt.CX 2 7 :: GateExt.CX 2 8 :: [])

def dec (a0 b0 a1 b1 a2 b2 : Bool) : List GateExt :=
    (if a0 then GateExt.X 3 :: [] else []) ++
    (if b0 then GateExt.Z 3 :: [] else []) ++
    (if a1 then GateExt.X 5 :: [] else []) ++
    (if b1 then GateExt.Z 5 :: [] else []) ++
    (if a2 then GateExt.X 7 :: [] else []) ++
    (if b2 then GateExt.Z 7 :: [] else [])


def myxor : Bool -> Bool -> Bool
  | false, false => false
  | false, true => true
  | true, false => true
  | true, true => false

infixl:55 " ^^ " => myxor

-- 2^12 = 4096 circuits for QHE for CCZ gate.
-- def cczcondcir (a0 b0 c0 k0 a1 b1 c1 k1 a2 b2 c2 k2 : Bool) : List GateExt := sorry

-- Discard qubit 0,1,2,4,6,8.
def discard : Ket 9 → Ket 3
  | (KetS b0 (KetS b1 (KetS b2 (KetS b3 (KetS b4 (KetS b5 (KetS b6 (KetS b7 b8)))))))) => KetS b3 (KetS b5 (b7))

-- Discard qubit 0,1,2.
def discard2 : Ket 9 → Ket 6
  | (KetS b0 (KetS b1 (KetS b2 (KetS b3 (KetS b4 (KetS b5 (KetS b6 (KetS b7 b8)))))))) =>(KetS b3 (KetS b4 (KetS b5 (KetS b6 (KetS b7 b8)))))

def discard_signed : Scalar × Ket 9 → Scalar × Ket 3
  | (sign , ket9) => (sign , discard ket9)

def discard_signed2 : Scalar × Ket 9 → Scalar × Ket 6
 | (sign , ket9) => (sign , discard2 ket9)

-- Definition of "implementing a scaled CCZ gate". Ignoring the
-- measured qubits and ignoring the six state |0> qubits for Bell
-- states, implementing a CCZ gate on the first three qubits means:

-- CCZ |111 000000> to -|111>
-- CCZ |abc 000000> to  |abc>  when not all a b c are 1.

-- The scaled version for some scalar s is:

-- CCZ |111 000000> to - s |111>
-- CCZ |abc 000000> to   s |abc>  when not all a b c are 1.


def _maps_to_ : List GateExt → Ket 9 → Scalar × Ket 3 → Prop
  | cir, a, b => map discard_signed (actext cir ((P1 , a) :: [])) = b :: []

def _maps2_to_ : List GateExt → Ket 9 → Scalar × Ket 6 → Prop
  | cir, a, b => map discard_signed2 (actext cir ((P1 , a) :: [])) = b :: []


def _implements_scaled_CCZ : List GateExt → Prop
  | cir =>
  ∃ (s : Scalar),
  _maps_to_ cir K000000000 (s , K000) ∧
  _maps_to_ cir K001000000 (s , K001) ∧
  _maps_to_ cir K010000000 (s , K010) ∧
  _maps_to_ cir K011000000 (s , K011) ∧
  _maps_to_ cir K100000000 (s , K100) ∧
  _maps_to_ cir K101000000 (s , K101) ∧
  _maps_to_ cir K110000000 (s , K110) ∧
  _maps_to_ cir K111000000 (s * M1 , K111)


def _implements_copy : List GateExt → Prop
  | cir =>
    _maps2_to_ cir K000000000 (P1 , K000000) ∧
    _maps2_to_ cir K001000000 (P1 , K000011) ∧
    _maps2_to_ cir K010000000 (P1 , K001100) ∧
    _maps2_to_ cir K011000000 (P1 , K001111) ∧
    _maps2_to_ cir K100000000 (P1 , K110000) ∧
    _maps2_to_ cir K101000000 (P1 , K110011) ∧
    _maps2_to_ cir K110000000 (P1 , K111100) ∧
    _maps2_to_ cir K111000000 (P1 , K111111)

open Eq
infixr:55 " ,, " => And.intro

theorem lemma1 : (a : Bool) -> a = a ∧ a = a ∧ a = a
  | true =>  show _ from refl _ ,, refl _ ,, refl _
  | false => show _ from refl _ ,, refl _ ,, refl _

set_option maxRecDepth 12000
set_option maxHeartbeats 800000
theorem lemma_copy : (c0 c1 c2 : Bool) → (_implements_copy (copy c0 c1 c2))
  | false, false, false => show _ from refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _
  | false, false, true => show _ from refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _
  | false, true, false => show _ from refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _
  | false, true, true => show _ from refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _
  | true, false, false => show _ from refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _
  | true, false, true => show _ from refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _
  | true, true, false => show _ from refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _
  | true, true, true => show _ from refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _ ,, refl _

#check lemma_copy

/-
-- Theorem: QHE scheme for CCZ works, i.e., for all skey a0,b0,a1,b1,a2,b2
-- and for all measurement result c0,k0,c1,k1,c2,k2 the circuit always
-- implements a scaled CCZ gate.
Theorem QHE_CCZ : ∀ (a0 b0 c0 k0 a1 b1 c1 k1 a2 b2 c2 k2 : Bool) →  _implements-scaled-CCZ (cczcond-cir a0 b0 c0 k0 a1 b1 c1 k1 a2 b2 c2 k2)
  | false, false, false, false, false, false, false, false, false, false, false, false => show _ from Exists.intro P1 (refl , refl , refl , refl , refl , refl , refl , refl
-/
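One possible contributor to the slowdown is a guess that has not been verified against this code. `gswap` recurses on both arguments in a way that is not structural in either one: `gswap (.succ j) l` keeps the first argument fixed, while `gswap j 0` keeps the second fixed. Lean therefore most likely compiles it with well-founded recursion, which unfolds poorly when `refl _` is checked by evaluation, and newer Lean versions mark such definitions irreducible by default. The sketch below rewrites the `gswap 0 j` case with structural recursion on a single `Nat`. `gswap0` is a hypothetical name, and it returns the indices `j` of the adjacent `Swap j` gates rather than `Gate` values so that it stands alone.

```lean
-- Hypothetical single-argument version of `gswap 0 j`: the list of
-- adjacent-swap indices whose product exchanges qubit 0 and qubit j.
-- Every recursive call is on a strictly smaller `Nat`, so Lean can
-- use structural recursion, which reduces well under `rfl` and `decide`.
def gswap0 : Nat → List Nat
  | 0 => []
  | 1 => [0]
  | j + 1 => j :: gswap0 j ++ [j]

-- Exchanging qubits 0 and 3 conjugates the swap (0,2) by the adjacent swap (2,3).
example : gswap0 3 = [2, 1, 0, 1, 2] := by decide
example : gswap0 1 = [0] := by rfl
```

The same idea, giving each helper a single decreasing argument, could be applied to the other `gswap` cases. Whether this fully explains the memory use here has not been checked.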

James Gallicchio (Jan 16 2024 at 04:07):

I haven't looked at all at your code, but is it possible to encode the theorem efficiently in SAT? We've been working on ways to send off proof goals to SAT solvers and get verified proofs back, as well as transport those SAT results to other problems encoded in SAT. Might work for your use case?

Xiaoning Bian (Jan 16 2024 at 14:48):

James Gallicchio said:

I haven't looked at all at your code, but is it possible to encode the theorem efficiently in SAT? We've been working on ways to send off proof goals to SAT solvers and get verified proofs back, as well as transport those SAT results to other problems encoded in SAT. Might work for your use case?

I don't know; right now I can't see how SAT would help. The equalities I want to prove are easy enough: 8 equalities of the form 'a = b0', each proved by simplifying 'a', and the simplification (actext in my code) is easy and short. In Agda it takes about 1 min and little memory. I don't need SAT to give me a clever solution; I just want Lean 4 to do the computation and check whether the result is the given 'b0'.
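A minimal sketch of this "just compute and check" style follows. The helper names `allBools`, `xorB` and `check` are hypothetical, and the property checked is only a toy stand-in for the circuit equalities. Instead of writing one match arm per case, the sketch enumerates every Boolean input as a list and folds the whole case distinction into a single `Bool`-valued function. Lean then evaluates it with `decide`, which the kernel checks, or with `native_decide`, which runs compiled code. `native_decide` is faster but additionally trusts the compiler through the `Lean.ofReduceBool` axiom.

```lean
-- All 2^n lists of n Booleans.
def allBools : Nat → List (List Bool)
  | 0 => [[]]
  | n + 1 => (allBools n).map (false :: ·) ++ (allBools n).map (true :: ·)

-- Toy stand-in for `myxor` from the original post.
def xorB : Bool → Bool → Bool
  | false, b => b
  | true, b => !b

-- Toy stand-in for "the circuit built from inputs bs gives the expected result".
def check : List Bool → Bool
  | a :: b :: _ => xorB (xorB a b) b == a
  | _ => true

-- 2^3 = 8 cases, evaluated and checked by the kernel.
example : (allBools 3).all check = true := by decide

-- 2^12 = 4096 cases, evaluated by compiled code.
example : (allBools 12).all check = true := by native_decide
```

To apply this to the `_maps_to_` statements, `decide` would need `DecidableEq` instances for `Scalar` and `Ket n`, because the code above derives only `BEq`. An alternative is to compare with `==` and prove that the resulting `Bool` equals `true`. Whether `deriving DecidableEq` works directly for the indexed family `Ket` has not been checked here.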


Last updated: May 02 2025 at 03:31 UTC