Zulip Chat Archive
Stream: Is there code for X?
Topic: Tensor Comprehensions
namibj (Jan 10 2024 at 01:46):
I really like the concept and general approach of FB AI Research's Tensor Comprehensions system (Repo, Docs, Paper) for concisely expressing the semantics of (most) deep learning kernels. It starts with typical BLAS routines, but goes beyond the common einsum approach, which restricts tensor contractions to the usual add/mul semiring over the f32 primitive type.
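To make the semiring point concrete, here is an example of a contraction that a plain einsum cannot express. TC's documentation lists reduction operators besides +=, such as max=. We have not checked every detail of the syntax, so treat the following TC line as an assumption: C(i,j) max=! A(i,k) + B(k,j). It would describe a max-plus ("tropical") matrix product. Below is a hypothetical C lowering. It assumes row-major storage and that the ! variant initializes the target to the identity of the reduction, which is -infinity for max.

```c
#include <math.h>   /* INFINITY */
#include <stddef.h>

/* Hypothetical lowering of  C(i,j) max=! A(i,k) + B(k,j)
 * Same loop nest as a GEMM, but over the (max, +) semiring instead of (+, *).
 * A is N x M, B is M x K, C is N x K, all row-major. */
void maxplus_mm(const float *A, const float *B, float *C,
                size_t N, size_t M, size_t K) {
    for (size_t i = 0; i < N; ++i) {
        for (size_t j = 0; j < K; ++j) {
            float acc = -INFINITY;                      /* identity of max: the "!" init */
            for (size_t k = 0; k < M; ++k) {
                float v = A[i * M + k] + B[k * K + j];  /* semiring "multiply" is + */
                if (v > acc) acc = v;                   /* semiring "add" is max */
            }
            C[i * K + j] = acc;
        }
    }
}
```

The loop structure is identical to sgemm's. Only the two semiring operations and the initial value change, which is exactly the generality einsum lacks.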
From the paper:
```
def sgemm(float a, float b, float(N,M) A, float(M,K) B) → (C) {
    C(i,j) = b * C(i,j)             # initialization
    C(i,j) += a * A(i,k) * B(k,j)   # accumulation
}

def mv(float(M,K) A, float(K) x) → (C) {
    C(i) +=! A(i,k) * x(k)
}

def fcrelu(float(B,I) in, float(O,I) weight, float(I) bias) → (out) {
    out(i,j) = bias(j)
    out(b,o) += in(b,i) * weight(o,i)
    out(i,j) = fmaxf(out(i,j), 0)
}
```
The +=! operator just means "zero-initialize the target before performing the accumulation/reduction".
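One plausible way to lower the mv comprehension to C loops makes the difference between += and +=! visible. The sketch assumes row-major storage and that the range of the reduction index k is inferred from the shape of A. The function name tc_mv and the explicit size parameters are hypothetical and are not actual TC output.

```c
#include <stddef.h>

/* Hypothetical hand lowering of the TC statement
 *     C(i) +=! A(i,k) * x(k)
 * with A stored row-major as an M x K array. */
void tc_mv(const float *A, const float *x, float *C, size_t M, size_t K) {
    for (size_t i = 0; i < M; ++i) {
        C[i] = 0.0f;                      /* the "!" part: reset the target */
        for (size_t k = 0; k < K; ++k) {
            C[i] += A[i * K + k] * x[k];  /* the "+=" part: reduce over k */
        }
    }
}
```

A plain += (as in sgemm's accumulation statement) would drop the reset line and accumulate into whatever C already holds.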
I'd love to use Lean to reason about, discuss, and explain custom kernels. This would in particular cover their forward-mode Jacobian-vector products and reverse-mode vector-Jacobian products, and would let me check or prove that the transformed kernels really compute the derivatives of the original kernel.
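As a numerical (not formal) illustration of what "the transformed kernel computes the derivative" means, here is a sketch for mv, y = A x. Because y is linear in x, the forward-mode JVP is dy = A dx and the reverse-mode VJP is xbar = A^T ybar. With respect to A, the VJP is the outer product Abar = ybar x^T. A cheap sanity check, well short of a proof, is the adjoint identity <ybar, A dx> = <A^T ybar, dx>. All function names here are hypothetical.

```c
#include <stddef.h>

/* All matrices are M x K, row-major; y = A x. */

/* Forward mode (JVP) w.r.t. x: dy = A dx (mv is linear in x). */
void mv_jvp_x(const float *A, const float *dx, float *dy, size_t M, size_t K) {
    for (size_t i = 0; i < M; ++i) {
        dy[i] = 0.0f;
        for (size_t k = 0; k < K; ++k) dy[i] += A[i * K + k] * dx[k];
    }
}

/* Reverse mode (VJP) w.r.t. x: xbar = A^T ybar. */
void mv_vjp_x(const float *A, const float *ybar, float *xbar, size_t M, size_t K) {
    for (size_t k = 0; k < K; ++k) xbar[k] = 0.0f;
    for (size_t i = 0; i < M; ++i)
        for (size_t k = 0; k < K; ++k) xbar[k] += A[i * K + k] * ybar[i];
}

/* Reverse mode (VJP) w.r.t. A: Abar = ybar x^T (outer product). */
void mv_vjp_A(const float *x, const float *ybar, float *Abar, size_t M, size_t K) {
    for (size_t i = 0; i < M; ++i)
        for (size_t k = 0; k < K; ++k) Abar[i * K + k] = ybar[i] * x[k];
}

/* Inner product, used to test the adjoint identity. */
float dot(const float *u, const float *v, size_t n) {
    float s = 0.0f;
    for (size_t k = 0; k < n; ++k) s += u[k] * v[k];
    return s;
}
```

In floating point, the two sides of the adjoint identity generally agree only up to rounding. Small integer inputs keep them exactly equal.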
Ideally, there would also be a way to go further. One step would be to partially specialize the dimensions to work out arithmetic intensity: the ratio of compute to memory traffic, which indicates whether an implementation is memory-bound or compute-bound, at least for sufficiently regular memory access patterns. Another would be to study the influence of numerical problems (e.g. catastrophic cancellation) on a concrete implementation that nominally respects the analytical definition but is tuned for practical performance.
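Once the dimensions are specialized, a back-of-the-envelope version of the arithmetic-intensity question is to count operations against compulsory memory traffic. The sketch rests on three assumptions. Each tensor is read or written to main memory exactly once (perfect caching). Floats are 4 bytes. A multiply-add counts as 2 FLOPs and each ReLU as 1 FLOP. Real kernels move at least this much data, so the result is an upper bound on intensity. The function name is hypothetical.

```c
#include <stddef.h>

/* Idealized arithmetic intensity (FLOPs per byte) of fcrelu with
 * input B x I, weight O x I, bias O, out B x O.
 * Assumes compulsory traffic only: every tensor crosses the memory bus once. */
double fcrelu_intensity(size_t B, size_t I, size_t O) {
    double b = (double)B, i = (double)I, o = (double)O;
    double flops = 2.0 * b * o * i  /* multiply-adds of the contraction */
                 + b * o;           /* one max per output element (ReLU) */
    double bytes = (double)sizeof(float) * (b * i + o * i + o + b * o);
    return flops / bytes;
}
```

For a single input vector (B = 1, I = O = 1024), this gives roughly 0.5 FLOP/byte, which is memory-bound on typical current hardware. B = I = O = 1024 gives about 170 FLOP/byte, which could be compute-bound, but only if the implementation actually achieves the reuse this model assumes. A naive triple loop typically does not.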
I'd appreciate suggestions or pointers on how to use a spiritually similar syntax in Lean that is at least somewhat amenable to mathlib4's analysis.
Tomas Skrivan (Jan 11 2024 at 08:29):
I haven't read the paper yet, but these functions can be implemented with SciLean (a library I'm working on). I'm not sure I have implemented fcrelu correctly: either I don't understand the notation, or it should be float(O) bias.
```lean
import SciLean
open SciLean ArrayType

def sgemm (a b : Float) (A : Float^[N,M]) (B : Float^[M,K]) (C : Float^[N,K]) : Float^[N,K] :=
  mapIdx (fun (i,j) cij => b * cij + ∑ k, a * A[(i,k)] * B[(k,j)]) C

def mv (A : Float^[M,K]) (x : Float^[K]) : Float^[M] :=
  ⊞ i => ∑ j, A[(i,j)] * x[j]

def fcrelu (in' : Float^[B,I]) (weight : Float^[O,I]) (bias : Float^[O]) : Float^[B,O] :=
  ⊞ (i,j) => 0 ⊔ (bias[j] + ∑ k, in'[(i,k)] * weight[(j,k)])
```
If I add differentiation rules for mapIdx and max, then SciLean can also generate derivatives of these functions and provide a "proof" of correctness. Right now, such a proof would not be complete, because SciLean does not prove all the elementary transformation rules and the implementation of n-dimensional arrays (Float^[n,m,k]) is not yet verified.
In the future, I would like to optimize/rewrite these functions using techniques and ideas from the paper Verified Tensor-Program Optimization Via High-Level Scheduling Rewrites, but I will also look at the paper you mentioned.
However, many of the things you mentioned will not be possible to prove or reason about if these functions are defined as normal Lean functions. Because Lean has function extensionality, you can't reason about the implementation of a function. Thus, reasoning about whether a function is memory- or compute-bound would require a different approach. Lastly, I have no clue how to reason about numerical problems like catastrophic cancellation.
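The following minimal sketch shows why the concrete accumulation order matters, even though it is invisible in the analytical definition. Over the reals, 1e8 + 1 - 1e8 = 1. In float, the spacing between representable values near 1e8 is 8, so the 1 is absorbed. The later cancellation of the two large terms then exposes that loss and leaves 0. Compensated summation recovers the lost term. We use Neumaier's variant of Kahan summation here, a technique swapped in for illustration; it is not something proposed in this thread. The example assumes IEEE single precision without excess precision (e.g. SSE on x86-64) and no -ffast-math, which may reassociate the sums or optimize away the compensation.

```c
#include <stddef.h>

/* Left-to-right float accumulation, the order a plain reduction loop uses. */
float sum_naive(const float *v, size_t n) {
    float acc = 0.0f;
    for (size_t k = 0; k < n; ++k) acc += v[k];
    return acc;
}

static float abs_f(float a) { return a < 0.0f ? -a : a; }

/* Neumaier's compensated summation: tracks the rounding error of each
 * addition in a separate correction term c and adds it back at the end. */
float sum_neumaier(const float *v, size_t n) {
    float sum = 0.0f, c = 0.0f;
    for (size_t k = 0; k < n; ++k) {
        float t = sum + v[k];
        if (abs_f(sum) >= abs_f(v[k]))
            c += (sum - t) + v[k];  /* low-order bits of v[k] were lost */
        else
            c += (v[k] - t) + sum;  /* low-order bits of sum were lost */
        sum = t;
    }
    return sum + c;
}
```

With v = {1e8f, 1.0f, -1e8f}, sum_naive returns 0 and sum_neumaier returns 1. Simply reordering the input to {1e8f, -1e8f, 1.0f} also makes the naive loop return 1. Two implementations that are equal over the reals can therefore differ in float.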
namibj (Jan 12 2024 at 03:29):
Hmm, is it possible to reason about the memory access vs. compute aspects from the version you've written there?
I have a Julia and a C implementation of fcrelu. The former has only been visually inspected, and the latter has been tested not to segfault. Both were generated by GPT-4 under my direction, with a relevant excerpt from the paper provided as the reference/semantics spec.
```julia
function fcrelu(input::Array{Float32,2}, weight::Array{Float32,2}, bias::Array{Float32,1})
    # Dimension assertions: input is B×I, weight is O×I, bias has length O
    B, I = size(input)
    O, I2 = size(weight)
    O2 = length(bias)
    @assert I == I2 && O == O2
    # Allocate the output tensor 'out'
    out = zeros(Float32, B, O)
    # First operation: initialize 'out' with 'bias'
    for o in 1:O, b in 1:B
        out[b, o] = bias[o]
    end
    # Second operation: accumulate the matrix multiplication
    for b in 1:B, o in 1:O, i in 1:I
        out[b, o] += input[b, i] * weight[o, i]
    end
    # Third operation: apply ReLU activation
    for b in 1:B, o in 1:O
        out[b, o] = max(out[b, o], 0f0)
    end
    return out
end
```
```c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

/* out (B x O) = relu(bias (O) + input (B x I) * weight (O x I)^T), all row-major */
void fcrelu(const float* input, const float* weight, const float* bias,
            float* out, int B, int I, int O) {
    // Initialize 'out' with 'bias'
    for (int o = 0; o < O; ++o) {
        for (int b = 0; b < B; ++b) {
            out[b * O + o] = bias[o];
        }
    }
    // Accumulate the matrix multiplication
    for (int b = 0; b < B; ++b) {
        for (int o = 0; o < O; ++o) {
            for (int i = 0; i < I; ++i) {
                out[b * O + o] += input[b * I + i] * weight[o * I + i];
            }
        }
    }
    // Apply ReLU activation
    for (int o = 0; o < O; ++o) {
        for (int b = 0; b < B; ++b) {
            out[b * O + o] = fmaxf(out[b * O + o], 0.0f);
        }
    }
}

int main(void) {
    // Example usage
    int B = 10; // Batch size
    int I = 5;  // Input size
    int O = 3;  // Output size
    // Allocate memory for input, weights, bias, and output
    float* input = malloc((size_t)B * I * sizeof(float));
    float* weight = malloc((size_t)O * I * sizeof(float));
    float* bias = malloc((size_t)O * sizeof(float));
    float* out = malloc((size_t)B * O * sizeof(float));
    if (!input || !weight || !bias || !out) {
        fprintf(stderr, "allocation failed\n");
        free(input); free(weight); free(bias); free(out);
        return 1;
    }
    // Fill the inputs with arbitrary deterministic test values
    for (int k = 0; k < B * I; ++k) input[k] = (float)(k % 7) - 3.0f;
    for (int k = 0; k < O * I; ++k) weight[k] = (float)(k % 5) * 0.25f;
    for (int o = 0; o < O; ++o) bias[o] = (float)o - 1.0f;
    // Call fcrelu
    fcrelu(input, weight, bias, out, B, I, O);
    // Print the first row of the result
    for (int o = 0; o < O; ++o) printf("out[0][%d] = %g\n", o, out[o]);
    // Free allocated memory
    free(input);
    free(weight);
    free(bias);
    free(out);
    return 0;
}
```
Tomas Skrivan (Jan 12 2024 at 13:22):
No, you can't reason about the memory access or compute of the Lean functions I wrote.
You would have to define a specialized language, its memory model and compute cost. Only then can you reason about stuff like that.
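A toy Lean sketch of both points, with all names hypothetical (this is not SciLean API). A deep embedding keeps the syntax as data, so a cost function can be defined on it; here that is just an operation count, standing in for a real compute/memory model. The denotations (ordinary Lean functions) of two different expressions are nevertheless provably equal by function extensionality. Any statement about those functions alone therefore cannot distinguish their costs.

```lean
-- A toy deep embedding of scalar expressions in one variable.
inductive ScalarExpr where
  | var : ScalarExpr
  | lit : Nat → ScalarExpr
  | add : ScalarExpr → ScalarExpr → ScalarExpr
  | mul : ScalarExpr → ScalarExpr → ScalarExpr

-- Denotation: the ordinary Lean function the expression computes.
def ScalarExpr.eval : ScalarExpr → Nat → Nat
  | .var, x => x
  | .lit n, _ => n
  | .add a b, x => a.eval x + b.eval x
  | .mul a b, x => a.eval x * b.eval x

-- Cost model on syntax: number of arithmetic operations.
def ScalarExpr.ops : ScalarExpr → Nat
  | .var => 0
  | .lit _ => 0
  | .add a b => a.ops + b.ops + 1
  | .mul a b => a.ops + b.ops + 1

def e₁ : ScalarExpr := .add .var .var
def e₂ : ScalarExpr := .add (.add .var .var) (.lit 0)

-- Same function by funext...
example : e₁.eval = e₂.eval := by
  funext x
  simp [e₁, e₂, ScalarExpr.eval]

-- ...but different cost, visible only on the syntax.
example : e₁.ops = 1 ∧ e₂.ops = 2 := ⟨rfl, rfl⟩
```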
Your Julia and C versions of fcrelu match my Lean implementation.
Last updated: May 02 2025 at 03:31 UTC