Zulip Chat Archive

Stream: general

Topic: structure being treated as inductive datatype


Frederick Pu (Dec 13 2024 at 22:10):

structure ShapedVector (Shape : Type v) (ShapedType : Shape → Type u) (shapes : Array Shape) where
    shapedVector : Vector (Sigma (fun shape : Shape => ShapedType shape)) shapes.size
    -- we use this definition instead of the quantifier definition for the constructor so that shape matching can be provided using `rfl`
    shapesMatch : (List.finRange shapes.size).map (fun i => (shapedVector.get i).1) = (List.finRange shapes.size).map (fun i => shapes[i])

/-
    simplified version of computation graph (no weight sharing)
    AutoDiffTree α β outShape
-/
inductive AutoDiffTree (α : Type u) (β : Type v) : List Nat → Type (max u v)
-- (kernel) invalid nested inductive datatype 'ShapedVector', nested inductive datatypes parameters cannot contain local variables.
| mk
    {shapes : Array (List Nat)}
    {outShape : List Nat}
    (ctx : EFunction α β shapes outShape)
    (parents : ShapedVector (List Nat) (AutoDiffTree α β) shapes)
    (tensor : (shape : List Nat) × DualTensor α β shape) : AutoDiffTree α β outShape

ShapedVector isn't really an inductive type, so is there any way to do what I'm doing without just inlining the contents of ShapedVector?

Kyle Miller (Dec 14 2024 at 00:05):

Every structure is an inductive with a single constructor
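A minimal illustration of this point, using a made-up `Point` structure that does not appear in the thread: the structure gets a single generated constructor `Point.mk`, can be pattern-matched like any inductive, and has a recursor `Point.rec`.

```lean
-- A hypothetical structure, not from the thread.
structure Point where
  x : Nat
  y : Nat

-- `Point.mk` is the single constructor Lean generates for the structure.
def origin : Point := Point.mk 0 0

-- Matching on that constructor works exactly as for any inductive type.
def Point.sum : Point → Nat
  | .mk x y => x + y

-- The recursor `Point.rec` exists because `Point` is an inductive type.
#check @Point.rec
```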

Kyle Miller (Dec 14 2024 at 00:06):

The issue isn't that it's being treated as an inductive data type (you need it to be treated as an inductive data type to make any progress here), but the "nested inductive datatypes parameters cannot contain local variables" part.
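A toy sketch of the distinction, with a hypothetical family `Node : Nat → Type` (a `Nat` index stands in for the `List Nat` shape). The rejected form is shown only as a comment; it is expected to fail for the same reason as the original, since the nested type's size argument `k` is a constructor argument. The accepted form packs each child's index into a `Sigma`, which is the same idea as the `Array (Σ shape, ...)` field used later in this thread.

```lean
-- Expected to be REJECTED ("nested inductive datatypes parameters cannot
-- contain local variables"): the nested `Vector` receives the local `k`.
--
--   inductive Node : Nat → Type
--     | mk {n : Nat} (k : Nat) (children : Vector (Σ m, Node m) k) : Node n
--
-- ACCEPTED variant: the nested `Array (Σ m, Node m)` mentions only `Node`
-- itself; each child's index lives inside the Sigma, not in a parameter.
inductive Node : Nat → Type
  | mk {n : Nat} (children : Array (Σ m, Node m)) : Node n

-- Non-recursive use: count the direct children.
def Node.numChildren {n : Nat} : Node n → Nat
  | .mk children => children.size
```

The cost of this design is that the types no longer tie the number or shapes of children to anything, so that constraint has to be checked or proved separately.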

Frederick Pu (Dec 14 2024 at 00:36):

So how do I fix it? It seems like this should be a perfectly sound way of defining a type.

Frederick Pu (Dec 14 2024 at 01:25):

Wait, so the main issue is that Vector is a nested datatype, right? Because even when I inline it, I still get the same issue.

Frederick Pu (Dec 14 2024 at 01:59):

So this seemed to solve it:

inductive AutoDiffTree (α : Type u) (β : Type v) : List Nat → Type (max u v)
| mk
    {outShape : List Nat}
    {shapes : Array (List Nat)}
    (parents : (shapes : Array (List Nat)) × (ShapedVector (List Nat) (Tensor α) shapes))
    (matchesShapes : parents.1 = shapes)
    (ctx : EFunction α β shapes outShape)
    (tensor : (shape : List Nat) × DualTensor α β shape)
    : AutoDiffTree α β outShape

Frederick Pu (Dec 14 2024 at 02:00):

That way, ShapedVector is never passed a local variable.

Frederick Pu (Dec 14 2024 at 18:22):

nvm, that didn't really do what it was supposed to, since I just replaced AutoDiffTree with Tensor. Here's what I actually ended up doing:

/- Data-carrying part of ComputedAutoDiffTree
-/
protected inductive ComputedAutoDiffTree.DiffTree (α : Type u) (β : Type v) : List Nat → Type (max u v)
| mk
    {outShape : List Nat}
    (parents : Array (Σ shape, ComputedAutoDiffTree.DiffTree α β shape))
    (saved_tensors : Σ shapes, ShapedVector (DualTensor α β) shapes)
    (ctx : Σ shapes, EFunction α β shapes outShape)
    (tensor : DualTensor α β outShape)
    : ComputedAutoDiffTree.DiffTree α β outShape

/- parents, saved_tensors and ctx shapes match at all levels
-/
protected partial def ComputedAutoDiffTree.DiffTree.valid {α : Type u} {β : Type v} {shape : List Nat} : ComputedAutoDiffTree.DiffTree α β shape → Prop
| mk #[] saved_tensors ctx _ => saved_tensors.1 = #[] ∧ ctx.1 = #[]
| mk parents saved_tensors ctx _ => parents.map (fun x => x.1) = saved_tensors.1 ∧ saved_tensors.1 = ctx.1 ∧ ∀ x ∈ parents, x.2.valid

/-
    AutoDiffTree where the saved_tensors are computed
-/
structure ComputedAutoDiffTree (α : Type u) (β : Type v) (shape : List Nat) where
    diffTree : ComputedAutoDiffTree.DiffTree α β shape
    isValid : diffTree.valid

Frederick Pu (Dec 14 2024 at 18:22):

There's no real way of enforcing the constraint within the inductive definition itself.
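A sketch of the "data plus proof" split, using hypothetical names (`RawPair`, `CheckedPair`, `CheckedPair.mk?`) and a deliberately non-recursive invariant so that no termination argument is needed. The raw structure carries the data, the constraint is a separate `Prop`, and a bundling structure pairs them, in the same way `ComputedAutoDiffTree` pairs `diffTree` with `isValid`.

```lean
-- Data-carrying part: nothing in the type ties `shapes` to `values`.
structure RawPair where
  shapes : Array Nat
  values : Array Nat

-- The constraint, stated separately as a Prop. `abbrev` keeps it reducible,
-- so instance search finds a `Decidable` instance (via `Nat` equality).
abbrev RawPair.valid (p : RawPair) : Prop :=
  p.shapes.size = p.values.size

-- Bundle: data plus a proof that it satisfies the constraint.
structure CheckedPair where
  raw : RawPair
  isValid : raw.valid

-- For concrete data, `decide` discharges the proof obligation.
def samplePair : CheckedPair :=
  { raw := { shapes := #[2, 3], values := #[10, 20] }, isValid := by decide }

-- Smart constructor: check the constraint at runtime instead of in the type.
def CheckedPair.mk? (r : RawPair) : Option CheckedPair :=
  if h : r.valid then some ⟨r, h⟩ else none
```

Making the predicate an `abbrev` is a choice of this sketch, not something from the thread; it is what lets both `by decide` and `if h : r.valid` work without writing a `Decidable` instance by hand.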

Frederick Pu (Dec 14 2024 at 18:23):

Also, does anyone know how I can show termination for the `valid` function? Do I need to modify DiffTree to be `Nat -> List Nat -> Type (max u v)`?


Last updated: May 02 2025 at 03:31 UTC