Home Notes on Kolmogorov–Arnold Networks
Post
Cancel

Notes on Kolmogorov–Arnold Networks

Reading the original code of the KAN: Kolmogorov-Arnold Networks paper was a bit confusing for me so I tried to implement a simple KAN from scratch. You can find the original code on pykan repo.

The code written in this post is hosted here.

This post is unfinished. Will be completed. { : .prompt-info }

B-spline Curves

Theory

reference: B-splines curves and surfaces

Basis Functions

Basis functions are basic building blocks of all the functions in a function space. Every function in the function space can be represented as a linear combination of its basis functions.

Think of exponentials that are basis functions of any periodic function that has a fourier series.

Spline

A spline is a piecewise polynomial function \(S: [a, b] \rightarrow \mathbb{R}\) with \(m+1\) knots \(t = [t_0, ..., t_m]\) where \(a=t_0\le t_1 \le ... \le t_m=b\) and:

\[S(x) = P_i(x), \qquad x\in[t_i, t_{i+1})\]

for \(i=0,..,m-1\).

If the polynomial pieces \(P_i\) each have degree at most \(p\), then the spline is said to be of degree \(p\) or of order \(p+1\).

B-spline

A B-spline (basis spline) of order \(p+1\) is a piecewise polynomial function \(B_{i, k}(x)\) of degree \(p\). B-splines are basis functions for spline functions of the same order and are defined as:

\[B_{i, 0}(x) = 1,\qquad if~x\in[t_i,t_{i+1}),~o.w.~0\]

and recursively:

\[B_{i, p}(x) = \frac{x - t_i}{t_{i+p}-t_i}B_{i, p-1}(x) + \frac{t_{i+p+1} - x}{t_{i+p+1}-t_{i+1}}B_{i+1, p-1}(x)\]

for \(p > 0\), and \(i=0,...,n-1\), and \(t = [t_0, ..., t_{n+p}]\).

B-spline Curve

A B-spline curve is defined as a linear combination of B-spline basis functions with the knot vector \(t = [t_0,..,t_p,..,t_n,..t_{n+p}]\):

\[S(x) = \sum_{i=0}^{n - 1} c_i B_{i, p}(x), \qquad x\in[t_{p}, t_{n}]: \text{internal knots}\]

with \(n > p\).

Implementation

1
2
import numpy as np
from matplotlib import pyplot as plt

Scaler

1
2
3
4
5
6
7
8
9
10
11
def b(x, i, p, t):
    if p == 0:
        return 1 if t[i] <= x < t[i + 1] else 0
    c_1 = (x - t[i]) / (t[i + p] - t[i])
    c_2 = (t[i + p + 1] - x) / (t[i + p + 1] - t[i + 1])
    b_ip = c_1 * b(x, i, p - 1, t) + c_2 * b(x, i + 1, p - 1, t)
    return b_ip


def b_curve(x, c, p, t):
    return sum([ci * b(x, i, p, t) for i, ci in enumerate(c)])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
p = 3  # dgree of polynomials
n = 6  # number of basis functions (> p)
m = n + p + 1  # number of knots
t = np.linspace(0, 1, m)  # uniform knot vector (t[p:n+1] are called internal knots)
x = np.linspace(t[0], t[-1], 100)

fig, ax = plt.subplots(figsize=(16, 4))
for i in range(n):
    y = [b(xi, i, p, t) for xi in x]
    ax.plot(x, y, label=f"i={i}")
ax.axvspan(t[p], t[n], color="gray", alpha=0.3, label="Internal Knots")
ax.set_title(
    "B-spline basis functions of degree 3 with Uniform Knots", fontsize="medium"
)
ax.legend(fontsize="small")
fig.tight_layout()

png

1
2
3
4
5
6
7
8
9
10
fig, ax = plt.subplots(4, 4, figsize=(12, 12), sharex=True, sharey=True)
ax = ax.flatten()

for i in range(16):
    c = np.random.random(n)
    y = [b_curve(xi, c, p, t) for xi in x]
    ax[i].plot(x, y)
    ax[i].axvspan(t[p], t[n], color="gray", alpha=0.3)
fig.suptitle("Random B-spline curves of degree 3 with uniform knots")
fig.tight_layout()

png

Batch

If x is a batch of shape $(N, d)$, with simple modification we can have the batched version of the basis functions.

1
2
3
4
5
6
7
8
def b_batch(x, i, p, t):
    if p == 0:
        out = (t[i] <= x) * (x < t[i + 1])
        return out.astype(int)
    c_1 = (x - t[i]) / (t[i + p] - t[i])
    c_2 = (t[i + p + 1] - x) / (t[i + p + 1] - t[i + 1])
    b_ip = c_1 * b_batch(x, i, p - 1, t) + c_2 * b_batch(x, i + 1, p - 1, t)
    return b_ip

Batch for all bases

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# For all i = 0, ..., n-1
def b_vectorized(x, p, t):
    m = t.shape[0]
    n = m - p - 1
    t = t.reshape(-1, 1, 1)
    if p == 0:
        out = (t[:-1] <= x) * (x < t[1:])
        return out.astype(int)

    c_1 = (x - t[:n]) / (t[p : p + n] - t[:n])
    c_2 = (t[p + 1 : n + p + 1] - x) / (t[p + 1 : n + p + 1] - t[1 : n + 1])
    b_pm1 = b_vectorized(x, p - 1, t)
    b_p = c_1 * b_pm1[:-1] + c_2 * b_pm1[1:]
    return b_p
1
2
3
4
5
6
7
8
9
p = 3  # dgree
n = 6  # number of basis functions (> p)
m = n + p + 1  # number of knots
N = 16  # batch size
d = 8  #
t = np.linspace(0, 1, m)  # uniform knot vector (t[p:n+1] are called internal knots)
x = np.random.random((N, d))
B = b_vectorized(x, p, t)
B.shape
1
(6, 16, 8)

Batch for all bases with different knot vectors

1
2
3
4
5
6
7
8
9
10
11
12
13
def b_vectorized(x, p, t):
    m, d = t.shape
    n = m - p - 1
    t_ = t.reshape(m, 1, d)
    if p == 0:
        out = (t_[:-1] <= x) * (x < t_[1:])
        return out.astype(int)

    c_1 = (x - t_[:n]) / (t_[p : p + n] - t_[:n])
    c_2 = (t_[p + 1 : n + p + 1] - x) / (t_[p + 1 : n + p + 1] - t_[1 : n + 1])
    b_pm1 = b_vectorized(x, p - 1, t)
    b_p = c_1 * b_pm1[:-1] + c_2 * b_pm1[1:]
    return b_p
1
2
3
4
5
6
7
8
9
10
11
p = 3  # dgree
n = 6  # number of basis functions (> p)
m = n + p + 1  # number of knots
N = 16  # batch size
d = 8  #
t = np.sort(
    np.random.random((m, d)), axis=1
)  # uniform knot vector (t[p:n+1] are called internal knots)
x = np.random.random((N, d))
B = b_vectorized(x, p, t)
B.shape
1
(6, 16, 8)
This post is licensed under CC BY 4.0 by the author.