<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://alexshtf.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://alexshtf.github.io/" rel="alternate" type="text/html" /><updated>2026-04-24T06:42:37+00:00</updated><id>https://alexshtf.github.io/feed.xml</id><title type="html">Alex Shtoff</title><subtitle>Blog on optimization, machine learning, and software development.</subtitle><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><entry><title type="html">Cheaper eigenvalue training and inference</title><link href="https://alexshtf.github.io/2026/03/15/Spectrum-Banded.html" rel="alternate" type="text/html" title="Cheaper eigenvalue training and inference" /><published>2026-03-15T00:00:00+00:00</published><updated>2026-03-15T00:00:00+00:00</updated><id>https://alexshtf.github.io/2026/03/15/Spectrum-Banded</id><content type="html" xml:base="https://alexshtf.github.io/2026/03/15/Spectrum-Banded.html"><![CDATA[<p align="center">
  <a href="https://colab.research.google.com/github/alexshtf/alexshtf.github.io/blob/master/assets/spectrum_power_tridiagonal.ipynb" target="_blank" rel="noopener">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" />
  </a>
</p>

<h1 id="intro">Intro</h1>

<p>Recall the model family we have been studying in this series,</p>

\[f({\boldsymbol x};{\boldsymbol A}_{0:n}) = \lambda_k \Bigl({\boldsymbol A}_0 + \sum_{i=1}^n x_i {\boldsymbol A}_i\Bigr),\]

<p>where each \(\boldsymbol A_i\) is a symmetric matrix. In the last post we discussed what these models predict, and how we can explain them to ourselves and to other stakeholders. Before that, we discussed GPU acceleration to make training and inference faster. Speed is important, but so is <em>cost</em>, and fast GPUs can be expensive. Here our aim is to make training and inference not only faster but also cheaper, by making the eigenvalue problem easier to solve even on weaker hardware. We certainly should not be paying for a GPU and waiting more than 5 minutes to train <em>one neuron</em> on a tabular dataset with about 20k rows, even if this one neuron is a fairly complex one! We begin our exploration with theory, which immediately yields practical applications. And as always, we have a <a href="https://github.com/alexshtf/alexshtf.github.io/blob/master/assets/spectrum_power_tridiagonal.ipynb">notebook</a> to reproduce all experiments in this post.</p>

<h1 id="simultaneous-simplification">Simultaneous simplification</h1>

<p>Recall that for any orthogonal matrix \({\boldsymbol Q} \in \mathbb{R}^{d \times d}\), we have</p>

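\[\lambda_k(\boldsymbol A) = \lambda_k({\boldsymbol Q}^T \boldsymbol A {\boldsymbol Q}).\]

<p>Since \({\boldsymbol Q}^T \bigl({\boldsymbol A}_0 + \sum_{i=1}^n x_i {\boldsymbol A}_i\bigr) {\boldsymbol Q} = {\boldsymbol Q}^T {\boldsymbol A}_0 {\boldsymbol Q} + \sum_{i=1}^n x_i {\boldsymbol Q}^T {\boldsymbol A}_i {\boldsymbol Q}\), our model family is invariant under such orthogonal similarity transformations: a model with matrices \(\boldsymbol A_i\) computes exactly the same function as a model with matrices \(\boldsymbol Q^T \boldsymbol A_i \boldsymbol Q\), for any orthogonal \(\boldsymbol Q\).</p>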
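<p>To make the invariance concrete, here is a small NumPy sketch. It uses random symmetric matrices as stand-ins for learned ones, rotates every \(\boldsymbol A_i\) by the same random orthogonal \(\boldsymbol Q\) (taken from a QR decomposition of a Gaussian matrix), and checks that the eigenvalues of \(\boldsymbol A_0 + \sum_i x_i \boldsymbol A_i\) do not change:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 6, 3  # matrix size and number of features (arbitrary choices)


def random_symmetric(d, rng):
    S = rng.standard_normal((d, d))
    return (S + S.T) / 2


def model_matrix(mats, x):
    """A_0 + sum_i x_i A_i for mats = [A_0, A_1, ..., A_n]."""
    return mats[0] + sum(xi * Ai for xi, Ai in zip(x, mats[1:]))


A = [random_symmetric(d, rng) for _ in range(n + 1)]  # A_0, ..., A_n
x = rng.standard_normal(n)

# A random orthogonal matrix from the QR decomposition of a Gaussian matrix
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
rotated = [Q.T @ Ai @ Q for Ai in A]

eig_orig = np.linalg.eigvalsh(model_matrix(A, x))     # ascending order
eig_rot = np.linalg.eigvalsh(model_matrix(rotated, x))
print(np.allclose(eig_orig, eig_rot))  # True
```

<p>All eigenvalues agree, not only the \(k\)-th one, so the invariance holds for every choice of \(k\).</p>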

<p>One of the interesting phenomena in linear algebra is <em>simultaneous diagonalization</em>. A set of matrices \({\boldsymbol A}_i\) is simultaneously diagonalizable if there exists an orthogonal matrix \({\boldsymbol Q}\) such that \({\boldsymbol Q}^T {\boldsymbol A}_i {\boldsymbol Q}\) is diagonal for all \(i\). In other words, the same matrix \(\boldsymbol Q\) diagonalizes all matrices simultaneously.</p>

<p>If we restrict ourselves to models where all of our learned matrices are simultaneously diagonalizable, we can just assume all matrices are diagonal:</p>

\[f({\boldsymbol x};{\boldsymbol A}_{0:n}) = \lambda_k \Bigl(\operatorname{diag}({\boldsymbol a}_0) + \sum_{i=1}^n x_i \operatorname{diag}({\boldsymbol a}_i)\Bigr).\]

<p>So what is the \(k\)-th eigenvalue of this matrix? It’s just the \(k\)-th smallest entry of the vector</p>

\[{\boldsymbol a}_0 + \sum_{i=1}^n x_i {\boldsymbol a}_i.\]

<p>It’s an extremely easy eigenvalue problem. But we lost almost all of the expressive power: the \(k\)-th smallest of \(d\) affine functions of \({\boldsymbol x}\) is just a convoluted way to describe a piecewise linear function of \({\boldsymbol x}\), and we have ReLU networks for that.</p>
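<p>The diagonal case can be sketched numerically as well. The snippet below is a minimal illustration with random vectors \(\boldsymbol a_i\) and a random shared orthogonal \(\boldsymbol Q\): it builds simultaneously diagonalizable matrices \(\boldsymbol A_i = \boldsymbol Q \operatorname{diag}(\boldsymbol a_i) \boldsymbol Q^T\) and checks that each eigenvalue of the model matrix equals the corresponding smallest entry of \(\boldsymbol a_0 + \sum_i x_i \boldsymbol a_i\). As a side check, it verifies that these matrices commute; for symmetric matrices, pairwise commuting is exactly the condition for simultaneous orthogonal diagonalization. Indices in the code are 0-based.</p>

```python
import numpy as np

rng = np.random.default_rng(1)
d, n = 5, 2

# A shared orthogonal Q makes all A_i = Q diag(a_i) Q^T simultaneously diagonalizable
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
a = rng.standard_normal((n + 1, d))  # rows: a_0, a_1, ..., a_n
A = [Q @ np.diag(ai) @ Q.T for ai in a]


def lambda_k_dense(x, k):
    """k-th smallest (0-based) eigenvalue of A_0 + sum_i x_i A_i."""
    M = A[0] + sum(xi * Ai for xi, Ai in zip(x, A[1:]))
    return np.linalg.eigvalsh(M)[k]


def lambda_k_diag(x, k):
    """k-th smallest (0-based) entry of a_0 + sum_i x_i a_i."""
    v = a[0] + x @ a[1:]
    return np.sort(v)[k]


x = rng.standard_normal(n)
matches = [np.isclose(lambda_k_dense(x, k), lambda_k_diag(x, k)) for k in range(d)]
print(all(matches))  # True

# Simultaneously diagonalizable symmetric matrices commute
print(np.allclose(A[1] @ A[2], A[2] @ A[1]))  # True
```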

<p>But there is another family of matrices for which the eigenvalue problem is easy: <em>symmetric tridiagonal</em> matrices, that is, matrices of the form:</p>

\[\mathcal{T}(\boldsymbol a, \boldsymbol b) = 
\begin{pmatrix}
a_1    &amp; b_1    &amp; 0      &amp; \dots  &amp; 0      \\
b_1    &amp; a_2    &amp; b_2    &amp; \dots  &amp; 0      \\
0      &amp; b_2    &amp; a_3    &amp; \ddots &amp; \vdots \\
\vdots &amp; \vdots &amp; \ddots &amp; \ddots &amp; b_{d-1} \\
0      &amp; 0      &amp; \dots  &amp; b_{d-1} &amp; a_d
\end{pmatrix}.\]

<p>Such a matrix is defined by two vectors: the main diagonal \(\boldsymbol a \in \mathbb{R}^d\) and the off-diagonal \(\boldsymbol b \in \mathbb{R}^{d-1}\). It turns out this family strikes a nice balance: eigenvalues of such matrices are efficient to compute, while the family remains fairly expressive. The efficiency comes from standing on the shoulders of giants, namely decades of numerical analysis research, handed to us on a silver platter in the form of <code class="language-plaintext highlighter-rouge">scipy.linalg.eigh_tridiagonal</code>.</p>
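<p>As a sanity check on the notation, here is a small sketch that assembles \(\mathcal{T}(\boldsymbol a, \boldsymbol b)\) densely from its two vectors (the helper name <code class="language-plaintext highlighter-rouge">tridiag_dense</code> is ours, for illustration only) and compares <code class="language-plaintext highlighter-rouge">scipy.linalg.eigvalsh</code> on the dense matrix with <code class="language-plaintext highlighter-rouge">scipy.linalg.eigvalsh_tridiagonal</code>, which only receives the two vectors. Both return eigenvalues in ascending order, and <code class="language-plaintext highlighter-rouge">select_range</code> uses 0-based indices:</p>

```python
import numpy as np
import scipy.linalg as sla


def tridiag_dense(a, b):
    """Dense T(a, b): a on the main diagonal, b just above and just below it."""
    return np.diag(a) + np.diag(b, k=1) + np.diag(b, k=-1)


rng = np.random.default_rng(2)
d = 8
a = rng.standard_normal(d)      # main diagonal, length d
b = rng.standard_normal(d - 1)  # off-diagonal, length d - 1

dense_vals = sla.eigvalsh(tridiag_dense(a, b))  # O(d^3) dense solver
tri_vals = sla.eigvalsh_tridiagonal(a, b)       # only touches the two vectors
print(np.allclose(dense_vals, tri_vals))  # True

# Selecting a single eigenvalue by its 0-based index, as in the timings
k = 3
kth = sla.eigvalsh_tridiagonal(a, b, select='i', select_range=(k, k))
print(np.isclose(kth[0], dense_vals[k]))  # True
```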

<p>To appreciate the speed difference, let’s time eigenvalue and eigenvector computation using SciPy for regular dense matrices, and compare it to tridiagonal matrices. Let’s create a batch of dense matrices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">scipy.linalg</span> <span class="k">as</span> <span class="n">sla</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="c1"># batch of 50 matrices of size 100x100; eigh and eigvalsh only read the lower
# triangle by default, so each matrix is implicitly treated as symmetric
</span><span class="n">M</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
</code></pre></div></div>
<p>Recall from previous posts that we need eigenvectors, in addition to eigenvalues, to compute the gradients required for training. So let’s measure both. Here it is for eigenvalues only:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span> <span class="o">-</span><span class="n">n</span> <span class="mi">100</span> <span class="o">-</span><span class="n">r</span> <span class="mi">30</span>
<span class="n">sla</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">subset_by_index</span><span class="o">=</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">50</span><span class="p">)).</span><span class="nb">sum</span><span class="p">()</span>  <span class="c1"># eigenvalue with 0-based index 50
</span></code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>32.7 ms ± 4.42 ms per loop (mean ± std. dev. of 30 runs, 100 loops each)
</code></pre></div></div>
<p>And here it is for eigenvalues together with eigenvectors:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span> <span class="o">-</span><span class="n">n</span> <span class="mi">100</span> <span class="o">-</span><span class="n">r</span> <span class="mi">30</span>
<span class="n">vals</span><span class="p">,</span> <span class="n">vecs</span> <span class="o">=</span> <span class="n">sla</span><span class="p">.</span><span class="n">eigh</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">subset_by_index</span><span class="o">=</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">50</span><span class="p">))</span>
<span class="n">vecs</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>34.5 ms ± 4.5 ms per loop (mean ± std. dev. of 30 runs, 100 loops each)
</code></pre></div></div>
<p>Alright. Now let’s do it for tridiagonal matrices. First, we generate diagonal and off-diagonal vectors:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># A batch of 50 diagonal and 50 off-diagonal vectors for 100x100 matrices.
</span><span class="n">d</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">e</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">99</span><span class="p">)</span>
</code></pre></div></div>
<p>Now let’s measure, first for eigenvalues only:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span> <span class="o">-</span><span class="n">n</span> <span class="mi">100</span> <span class="o">-</span><span class="n">r</span> <span class="mi">30</span>
<span class="n">sla</span><span class="p">.</span><span class="n">eigvalsh_tridiagonal</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">e</span><span class="p">,</span> <span class="n">select</span><span class="o">=</span><span class="s">'i'</span><span class="p">,</span> <span class="n">select_range</span><span class="o">=</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">50</span><span class="p">)).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>5.11 ms ± 60.1 µs per loop (mean ± std. dev. of 30 runs, 100 loops each)
</code></pre></div></div>
<p>And then for eigenvalues together with eigenvectors:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span> <span class="o">-</span><span class="n">n</span> <span class="mi">100</span> <span class="o">-</span><span class="n">r</span> <span class="mi">30</span>
<span class="n">vals</span><span class="p">,</span> <span class="n">vecs</span> <span class="o">=</span> <span class="n">sla</span><span class="p">.</span><span class="n">eigh_tridiagonal</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">e</span><span class="p">,</span> <span class="n">select</span><span class="o">=</span><span class="s">'i'</span><span class="p">,</span> <span class="n">select_range</span><span class="o">=</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">50</span><span class="p">))</span>
<span class="n">vecs</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>6.68 ms ± 105 µs per loop (mean ± std. dev. of 30 runs, 100 loops each)
</code></pre></div></div>
<p>That’s roughly a 6.4x speedup for eigenvalues alone and a 5.2x speedup with eigenvectors! But speed is not all we need: we also need representation power, which we explore in the next section.</p>

<h1 id="tridiagonal-eigenvalue-functions">Tridiagonal eigenvalue functions</h1>

<p>In the last post, we saw that, via the Courant-Fischer min-max theorem, we can rewrite our eigenvalue models as optimization problems over quadratic functions:</p>

\[f({\boldsymbol x};{\boldsymbol A}_{0:n}) = \max_{ {\boldsymbol C} \in \mathbb{R}^{(k-1)\times d}} \min_{ {\boldsymbol u} \in \mathbb{R}^d} \left\{ {\boldsymbol u}^T \mathcal{A}(\boldsymbol x) {\boldsymbol u} : \| {\boldsymbol u} \|_2 = 1, \, {\boldsymbol C}{\boldsymbol u} = {\boldsymbol 0}\right\},\]

<p>where</p>

\[\mathcal{A}(\boldsymbol x) = {\boldsymbol A}_0 + \sum_{i=1}^n x_i {\boldsymbol A}_i.\]
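<p>The min-max formula above can be checked numerically for a single matrix. The following sketch assumes a random symmetric \(\boldsymbol A\) in place of \(\mathcal{A}(\boldsymbol x)\) and evaluates the inner minimum through an orthonormal basis \(\boldsymbol N\) of the null space of \(\boldsymbol C\): writing \(\boldsymbol u = \boldsymbol N \boldsymbol y\) with \(\|\boldsymbol y\|_2 = 1\), the inner problem becomes the smallest eigenvalue of \(\boldsymbol N^T \boldsymbol A \boldsymbol N\). Choosing the rows of \(\boldsymbol C\) as the eigenvectors of the \(k-1\) smallest eigenvalues attains \(\lambda_k\), while random choices of \(\boldsymbol C\) never exceed it:</p>

```python
import numpy as np
import scipy.linalg as sla

rng = np.random.default_rng(3)
d, k = 6, 3  # k is 1-based here, matching lambda_k in the formula

S = rng.standard_normal((d, d))
A = (S + S.T) / 2
vals, vecs = np.linalg.eigh(A)  # ascending eigenvalues, orthonormal eigenvectors


def inner_min(A, C):
    """min of u^T A u over unit vectors u with C u = 0."""
    N = sla.null_space(C)                     # orthonormal columns spanning {u : C u = 0}
    return np.linalg.eigvalsh(N.T @ A @ N)[0]  # smallest eigenvalue of the projection


# The maximizing C: its rows are the eigenvectors of the k-1 smallest eigenvalues
C_best = vecs[:, :k - 1].T
print(np.isclose(inner_min(A, C_best), vals[k - 1]))  # True

# Any other C yields an inner minimum that is at most lambda_k
for _ in range(1000):
    C = rng.standard_normal((k - 1, d))
    assert inner_min(A, C) <= vals[k - 1] + 1e-10
```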

<p>So we have a <em>latent variable</em> \(\boldsymbol u\) that appears in the quadratic function \({\boldsymbol u}^T \mathcal{A}(\boldsymbol x) \boldsymbol u\), which expresses interactions between <em>all pairs of entries</em> of \(\boldsymbol u\), since:</p>

\[{\boldsymbol u}^T \mathcal{A}(\boldsymbol x) \boldsymbol u = \sum_{i=1}^d \sum_{j=1}^d \bigl(\mathcal{A}(\boldsymbol x)\bigr)_{i,j} u_i u_j.\]

<p>If \(\mathcal{A}(\boldsymbol x)\) were diagonal, we would lose all interactions - each entry \(u_i\) interacts only with itself:</p>

\[{\boldsymbol u}^T \mathcal{A}(\boldsymbol x) \boldsymbol u = \sum_{i=1}^d \bigl(\mathcal{A}(\boldsymbol x)\bigr)_{i,i} u_i^2.\]

<p>This is another manifestation of the loss of expressiveness we discussed before. But if it were tridiagonal, we would still retain interactions between neighboring pairs:</p>

\[{\boldsymbol u}^T \mathcal{A}(\boldsymbol x) \boldsymbol u = \sum_{i=1}^d \bigl(\mathcal{A}(\boldsymbol x)\bigr)_{i,i} u_i^2 + 2 \sum_{i=1}^{d-1} \bigl(\mathcal{A}(\boldsymbol x)\bigr)_{i+1,i} u_{i+1} u_{i}.\]
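<p>A quick numerical check of this tridiagonal expansion, with random \(\boldsymbol a\), \(\boldsymbol b\), and \(\boldsymbol u\) standing in for the diagonal, the off-diagonal, and the latent vector (a sketch, not code from the post’s notebook):</p>

```python
import numpy as np

rng = np.random.default_rng(4)
d = 7
a = rng.standard_normal(d)      # diagonal entries (i, i)
b = rng.standard_normal(d - 1)  # off-diagonal entries (i+1, i) and (i, i+1)
T = np.diag(a) + np.diag(b, 1) + np.diag(b, -1)
u = rng.standard_normal(d)

full = u @ T @ u  # double sum over all (i, j) pairs
banded = np.sum(a * u**2) + 2 * np.sum(b * u[1:] * u[:-1])  # only neighbors interact
print(np.isclose(full, banded))  # True
```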

<p>Even though the interactions are only between adjacent pairs \(u_i\) and \(u_{i+1}\), this turns out to be enough to produce a fairly rich set of models. Note that these are pairwise interactions between entries of the latent variable \(\boldsymbol u\), not between the raw features \(\boldsymbol x\). In fact, <em>all features</em> of \(\boldsymbol x\) can potentially interact with each other, since each entry of \(\mathcal{A}(\boldsymbol x)\) is an affine function of all features.</p>
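<p>To see how every feature enters every matrix entry, here is a sketch of the multivariate model with tridiagonal \(\boldsymbol A_i\). The helper <code class="language-plaintext highlighter-rouge">tridiag_model</code> and its parameter layout (row 0 of <code class="language-plaintext highlighter-rouge">D</code> and <code class="language-plaintext highlighter-rouge">E</code> holds \(\boldsymbol A_0\), row \(i\) holds \(\boldsymbol A_i\)) are hypothetical, illustrative choices. Each diagonal and off-diagonal entry is the product of the padded feature vector \((1, \boldsymbol x)\) with a column of parameters, so it is an affine function of all features. The helper loops over rows rather than relying on batched SciPy calls, trading speed for simplicity:</p>

```python
import numpy as np
import scipy.linalg as sla


def tridiag_model(X, D, E, k):
    """Hypothetical helper: lambda_k(A_0 + sum_i x_i A_i) for each row x of X.

    X: (m, n) features; D: (n + 1, d) diagonals; E: (n + 1, d - 1) off-diagonals,
    where row 0 holds A_0 and row i holds A_i. k is a 0-based eigenvalue index.
    """
    padded = np.hstack([np.ones((X.shape[0], 1)), X])  # prepend 1 for A_0
    diag = padded @ D      # (m, d): each entry mixes all features
    off_diag = padded @ E  # (m, d - 1)
    return np.array([
        sla.eigvalsh_tridiagonal(dg, od, select='i', select_range=(k, k))[0]
        for dg, od in zip(diag, off_diag)
    ])


rng = np.random.default_rng(5)
m, n, d = 4, 3, 6
X = rng.standard_normal((m, n))
D = rng.standard_normal((n + 1, d))
E = rng.standard_normal((n + 1, d - 1))
y = tridiag_model(X, D, E, k=2)
print(y.shape)  # (4,)


# Cross-check the first row against an explicitly assembled dense matrix
def dense(dg, od):
    return np.diag(dg) + np.diag(od, 1) + np.diag(od, -1)


A = [dense(D[i], E[i]) for i in range(n + 1)]
M = A[0] + sum(X[0, i] * A[i + 1] for i in range(n))
print(np.isclose(y[0], np.linalg.eigvalsh(M)[2]))  # True
```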

<p>To see visually that we retain nontrivial expressive power, let’s plot a univariate function:</p>

\[f_k(x) = \lambda_k(\boldsymbol A + x \boldsymbol B),\]

<p>where both matrices are tridiagonal, each specified by its diagonal and off-diagonal vectors. Here is a simple implementation of \(f_k(x)\) above, which evaluates it on a whole vector of points \(x\) with one batched SciPy call:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">tridiagonal_eig_1d</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">diag</span><span class="p">,</span> <span class="n">off_diag</span><span class="p">,</span> <span class="n">xs</span><span class="p">):</span>
    <span class="sa">r</span><span class="s">"""Univariate matrix pencil eigenvalue.
        f(x) = \lambda_k(A + x B)
    where A and B are both tridiagonal.

    Args:
        k (int): the (0-based) eigenvalue index
        diag (array): a 2 x n array of the diagonals of A and B
        off_diag (array): a 2 x (n - 1) array of the off-diagonals of A and B
        xs (array): a vector of m values x at which to evaluate f(x).

    Returns:
        An array y of length m with y[i] = f(xs[i])
    """</span>

    <span class="n">padded_xs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">c_</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">xs</span><span class="p">),</span> <span class="n">xs</span><span class="p">]</span>
    <span class="n">mat_diag</span> <span class="o">=</span> <span class="n">padded_xs</span> <span class="o">@</span> <span class="n">diag</span>         <span class="c1"># m x n
</span>    <span class="n">mat_off_diag</span> <span class="o">=</span> <span class="n">padded_xs</span> <span class="o">@</span> <span class="n">off_diag</span> <span class="c1"># m x (n - 1)
</span>    <span class="n">eigval</span> <span class="o">=</span> <span class="n">sla</span><span class="p">.</span><span class="n">eigvalsh_tridiagonal</span><span class="p">(</span>
        <span class="n">mat_diag</span><span class="p">,</span> <span class="n">mat_off_diag</span><span class="p">,</span> <span class="n">select</span><span class="o">=</span><span class="s">'i'</span><span class="p">,</span> <span class="n">select_range</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span>
    <span class="p">)</span>
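    <span class="k">return</span> <span class="n">eigval</span><span class="p">.</span><span class="n">ravel</span><span class="p">()</span>  <span class="c1"># shape (m,): one eigenvalue per x</span>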
</code></pre></div></div>

<p>Let’s plot the functions obtained from random \(5 \times 5\) matrices. Below is a function that plots a grid of the eigenvalue functions \(f_k(x)\) for all \(k\), followed by a call that plots ours:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">math</span>


<span class="k">def</span> <span class="nf">plot_tridiag_eig_1d</span><span class="p">(</span><span class="n">diag</span><span class="p">,</span> <span class="n">off_diag</span><span class="p">,</span> <span class="n">xmin</span><span class="o">=-</span><span class="mi">3</span><span class="p">,</span> <span class="n">xmax</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">resolution</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">fn</span><span class="o">=</span><span class="n">tridiagonal_eig_1d</span><span class="p">):</span>
    <span class="n">dim</span> <span class="o">=</span> <span class="n">diag</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    <span class="n">n_rows</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">dim</span><span class="p">))</span>
    <span class="n">n_cols</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">dim</span> <span class="o">/</span> <span class="n">n_rows</span><span class="p">))</span>
    <span class="n">xs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">resolution</span><span class="p">)</span>
    <span class="n">fig</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="s">'constrained'</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">dim</span><span class="p">),</span> <span class="n">axs</span><span class="p">.</span><span class="n">ravel</span><span class="p">()):</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">fn</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">diag</span><span class="p">,</span> <span class="n">off_diag</span><span class="p">,</span> <span class="n">xs</span><span class="p">))</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">'$</span><span class="se">\\</span><span class="s">lambda</span><span class="si">{</span><span class="mi">1</span><span class="o">+</span><span class="n">i</span><span class="si">}</span><span class="s">$'</span><span class="p">)</span>
    <span class="n">fig</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>


<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">plot_tridiag_eig_1d</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_tridiag_5x5.png" alt="pow_spec_tridiag_5x5" /></p>

<p>Alright! We see functions with non-trivial shapes. As expected from what we saw in previous posts, the smallest eigenvalue \(\lambda_1\) is concave, the largest \(\lambda_5\) is convex, and all other eigenvalue functions have piecewise-smooth shapes that are neither convex nor concave. What about \(11\times 11\) matrices?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plot_tridiag_eig_1d</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">11</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_tridiag_11x11.png" alt="pow_spec_tridiag_11x11" /></p>

<p>As expected, larger random matrices produce “wilder” shapes: the family of functions is richer.</p>
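<p>The concavity of \(\lambda_1\) and convexity of the largest eigenvalue, noted for the \(5 \times 5\) plots above, can also be checked numerically. Below is a minimal sketch, assuming fresh random matrices (drawn with NumPy’s <code class="language-plaintext highlighter-rouge">default_rng</code>, so not the ones plotted), that computes all eigenvalues of \(\boldsymbol A + x \boldsymbol B\) on an evenly spaced grid and checks the sign of the discrete second differences, which are non-positive for a concave function and non-negative for a convex one:</p>

```python
import numpy as np
import scipy.linalg as sla

rng = np.random.default_rng(42)
d = 5
diag = rng.standard_normal((2, d))          # diagonals of A and B
off_diag = rng.standard_normal((2, d - 1))  # off-diagonals of A and B

xs = np.linspace(-3, 3, 601)
# All eigenvalues of A + x B for every grid point, in ascending order
eigs = np.array([
    sla.eigvalsh_tridiagonal(diag[0] + x * diag[1], off_diag[0] + x * off_diag[1])
    for x in xs
])  # shape (601, d)

second_diff = eigs[2:] - 2 * eigs[1:-1] + eigs[:-2]
tol = 1e-9  # slack for floating-point rounding
print(np.all(second_diff[:, 0] <= tol))    # lambda_1 is concave: True
print(np.all(second_diff[:, -1] >= -tol))  # lambda_d is convex: True
```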

<p>Now that we’ve convinced ourselves that tridiagonal matrices have some potential, as a family providing a reasonable balance between speed and expressiveness, let’s move on to a more convincing demonstration of that potential.</p>

<h1 id="training-tridiagonal-matrix-eigenvalue-models">Training tridiagonal matrix eigenvalue models</h1>

<p>If we want to train with PyTorch, we first need fast tridiagonal eigenvalue computation there as well. Unfortunately, as of now (PyTorch 2.10), PyTorch does <em>not</em> have fast tridiagonal eigenvalue routines, even though tridiagonal and banded matrices appear in many scientific computing domains. So, as in a <a href="/2026/01/20/Spectrum-Speed.html">previous post</a>, we will implement a custom autograd function that forwards PyTorch tensors to SciPy routines.</p>

<p>As a reminder, we need to subclass <code class="language-plaintext highlighter-rouge">torch.autograd.Function</code> and implement two static methods: <code class="language-plaintext highlighter-rouge">forward</code> for the computation and <code class="language-plaintext highlighter-rouge">backward</code> for back-propagating derivatives. The backward pass is exactly where we need eigenvectors, and not only eigenvalues, as we explained in this previous <a href="/2026/01/20/Spectrum-Speed.html">post</a> in the series. Recall that for the function \(\lambda_k(\boldsymbol X)\), the “right kind” of generalized derivative is the matrix \(\boldsymbol q_k \boldsymbol q_k^T\), where \(\boldsymbol q_k\) is the corresponding unit-norm eigenvector. When \(\boldsymbol X\) is tridiagonal, we only need the diagonal and off-diagonal of \(\boldsymbol q_k \boldsymbol q_k^T\): the derivative with respect to the diagonal entry \(a_i\) is \(q_{k,i}^2\), and the derivative with respect to the off-diagonal entry \(b_i\) is \(2 q_{k,i} q_{k,i+1}\), where the factor 2 appears because \(b_i\) occupies both positions \((i, i+1)\) and \((i+1, i)\) of the symmetric matrix.</p>
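<p>These derivatives are easy to verify with central finite differences. The sketch below assumes a random tridiagonal matrix; since all its off-diagonal entries are nonzero, its eigenvalues are distinct and \(\lambda_k\) is differentiable there. It compares the diagonal of \(\boldsymbol q_k \boldsymbol q_k^T\), namely \(q_{k,i}^2\), and twice its off-diagonal, namely \(2 q_{k,i} q_{k,i+1}\) (twice, because each \(b_i\) sits in two symmetric positions), with numerical derivatives of <code class="language-plaintext highlighter-rouge">eigvalsh_tridiagonal</code>:</p>

```python
import numpy as np
import scipy.linalg as sla

rng = np.random.default_rng(7)
d, k = 6, 2  # k is a 0-based eigenvalue index
a = rng.standard_normal(d)
b = rng.standard_normal(d - 1)


def lam(a, b):
    return sla.eigvalsh_tridiagonal(a, b, select='i', select_range=(k, k))[0]


_, q = sla.eigh_tridiagonal(a, b, select='i', select_range=(k, k))
q = q[:, 0]  # unit-norm eigenvector of the k-th eigenvalue

grad_a = q**2                # diagonal of q q^T
grad_b = 2 * q[:-1] * q[1:]  # b_i appears at (i, i+1) and (i+1, i), hence the 2

# Central finite differences along each coordinate direction
h = 1e-6
fd_a = np.array([(lam(a + h * e, b) - lam(a - h * e, b)) / (2 * h) for e in np.eye(d)])
fd_b = np.array([(lam(a, b + h * e) - lam(a, b - h * e)) / (2 * h) for e in np.eye(d - 1)])
print(np.allclose(grad_a, fd_a, atol=1e-5), np.allclose(grad_b, fd_b, atol=1e-5))  # True True
```

<p>The sign of the eigenvector returned by SciPy does not matter here, since both formulas are products of two of its entries.</p>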

<p>So below is an autograd function implementing exactly this idea. It appears a bit lengthy, but that’s primarily because it aims to be efficient by distinguishing between two cases: (a) when we need derivatives, e.g., during training, and therefore require an eigenvector, and (b) when we do not need derivatives, e.g., during inference, and do <em>not</em> require an eigenvector:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>

<span class="k">class</span> <span class="nc">TridiagEigvalsh</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">autograd</span><span class="p">.</span><span class="n">Function</span><span class="p">):</span>
    <span class="o">@</span><span class="nb">staticmethod</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">diag</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">off_diag</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="s">"""The k-th eigenvalue of a batch of tridiagonal matrices.

        Args:
            diag (tensor): A M1 x ... x Mn x N tensor representing a batch
                of size M1 x ... x Mn of diagonals of NxN tridiagonal symmetric
                matrices.
            off_diag (tensor): A M1 x ... x Mn x (N - 1) tensor representing
                a batch of size M1 x ... x Mn of off-diagonals of NxN
                tridiagonal symmetric matrices.
            k (int): The (0-based) eigenvalue index.

        Returns:
            A M1 x ... x Mn tensor with the k-th eigenvalue of each matrix.
        """</span>
        <span class="n">need_grad</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">needs_input_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">or</span> <span class="n">ctx</span><span class="p">.</span><span class="n">needs_input_grad</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>

        <span class="n">diag_np</span> <span class="o">=</span> <span class="n">diag</span><span class="p">.</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">()</span>
        <span class="n">off_diag_np</span> <span class="o">=</span> <span class="n">off_diag</span><span class="p">.</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">()</span>
        <span class="k">if</span> <span class="n">need_grad</span><span class="p">:</span>
            <span class="c1"># k-th eigenvalue and eigenvector
</span>            <span class="n">ws_np</span><span class="p">,</span> <span class="n">Qs_np</span> <span class="o">=</span> <span class="n">sla</span><span class="p">.</span><span class="n">eigh_tridiagonal</span><span class="p">(</span>
                <span class="n">diag_np</span><span class="p">,</span> <span class="n">off_diag_np</span><span class="p">,</span> <span class="n">select</span><span class="o">=</span><span class="s">'i'</span><span class="p">,</span> <span class="n">select_range</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">),</span>
                <span class="n">lapack_driver</span><span class="o">=</span><span class="s">"stemr"</span>
            <span class="p">)</span>
            <span class="n">ws</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">ws_np</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">diag</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">Qs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">Qs_np</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">diag</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">ctx</span><span class="p">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">Qs</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="c1"># only k-th eigenvalue
</span>            <span class="n">ws_np</span> <span class="o">=</span> <span class="n">sla</span><span class="p">.</span><span class="n">eigvalsh_tridiagonal</span><span class="p">(</span>
                <span class="n">diag_np</span><span class="p">,</span> <span class="n">off_diag_np</span><span class="p">,</span> <span class="n">select</span><span class="o">=</span><span class="s">'i'</span><span class="p">,</span> <span class="n">select_range</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span>
            <span class="p">)</span>
            <span class="n">ws</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">ws_np</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">diag</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">ws</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># k-th eigenvalue
</span>
    <span class="o">@</span><span class="nb">staticmethod</span>
    <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_w</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
        <span class="p">(</span><span class="n">Qs</span><span class="p">,)</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">saved_tensors</span>  <span class="c1"># (..., N) from SciPy
</span>
        <span class="n">grad_w</span> <span class="o">=</span> <span class="n">grad_w</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">Qs</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>                 <span class="c1"># (...)
</span>        <span class="n">gw</span> <span class="o">=</span> <span class="n">grad_w</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>                          <span class="c1"># (..., 1)
</span>
        <span class="n">grad_diag</span> <span class="o">=</span> <span class="n">gw</span> <span class="o">*</span> <span class="n">Qs</span><span class="p">.</span><span class="n">square</span><span class="p">()</span>                       <span class="c1"># (..., N)
</span>        <span class="n">grad_off</span>  <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">gw</span> <span class="o">*</span> <span class="p">(</span><span class="n">Qs</span><span class="p">[...,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">Qs</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:])</span>  <span class="c1"># (..., N-1)
</span>
        <span class="k">return</span> <span class="n">grad_diag</span><span class="p">,</span> <span class="n">grad_off</span><span class="p">,</span> <span class="bp">None</span>
</code></pre></div></div>
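<p>The backward pass relies on a classical fact: if \(\boldsymbol v\) is the unit eigenvector of the \(k\)-th eigenvalue, the derivative with respect to the diagonal entry \(d_i\) is \(v_i^2\). The derivative with respect to the off-diagonal entry \(e_i\) is \(2 v_i v_{i+1}\), because \(e_i\) appears twice in the symmetric matrix. As a minimal sanity check, we can compare these closed-form gradients with central finite differences on a small random matrix. The helpers <code class="language-plaintext highlighter-rouge">tridiag_dense</code> and <code class="language-plaintext highlighter-rouge">kth_eigval</code> are hypothetical names used only in this sketch, and the dense eigensolver serves purely as a reference:</p>

```python
import torch

torch.manual_seed(0)
N, k, h = 6, 2, 1e-6
d = torch.randn(N, dtype=torch.float64)
e = torch.randn(N - 1, dtype=torch.float64)

def tridiag_dense(d, e):
    # Symmetric tridiagonal matrix: d on the diagonal, e right above and below it
    return torch.diag_embed(d) + torch.diag_embed(e, offset=1) + torch.diag_embed(e, offset=-1)

def kth_eigval(d, e):
    # eigvalsh returns eigenvalues in ascending order, so index k is the k-th smallest
    return torch.linalg.eigvalsh(tridiag_dense(d, e))[k]

# Closed-form gradients, the same formulas as in TridiagEigvalsh.backward
_, V = torch.linalg.eigh(tridiag_dense(d, e))
v = V[:, k]
grad_d = v.square()
grad_e = 2 * v[:-1] * v[1:]

# Central finite differences, perturbing one coordinate at a time
fd_d = torch.stack([(kth_eigval(d + h * u, e) - kth_eigval(d - h * u, e)) / (2 * h)
                    for u in torch.eye(N, dtype=torch.float64)])
fd_e = torch.stack([(kth_eigval(d, e + h * u) - kth_eigval(d, e - h * u)) / (2 * h)
                    for u in torch.eye(N - 1, dtype=torch.float64)])

print(torch.allclose(grad_d, fd_d, atol=1e-6), torch.allclose(grad_e, fd_e, atol=1e-6))
```

<p>The check assumes the \(k\)-th eigenvalue is simple. At repeated eigenvalues the eigenvalue function is not differentiable, and neither formula applies.</p>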

<p>Now let’s try it out. Here is a batch of 50 random tridiagonal matrices of size 100x100, given by their diagonals and off-diagonals:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">diags</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">off_diags</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">50</span><span class="p">,</span> <span class="mi">99</span><span class="p">)</span>
</code></pre></div></div>
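<p>Before timing anything, it is worth checking that the tridiagonal route agrees with a dense eigensolver. We assume here that the SciPy routine called in the forward pass is <code class="language-plaintext highlighter-rouge">scipy.linalg.eigh_tridiagonal</code>, whose <code class="language-plaintext highlighter-rouge">select</code> and <code class="language-plaintext highlighter-rouge">select_range</code> arguments match the call above. The sketch uses float64 copies named <code class="language-plaintext highlighter-rouge">diags64</code> and <code class="language-plaintext highlighter-rouge">off_diags64</code>, so it does not overwrite the tensors used in the timings:</p>

```python
import torch
from scipy.linalg import eigh_tridiagonal

torch.manual_seed(0)
k = 2
diags64 = torch.randn(50, 100, dtype=torch.float64)
off_diags64 = torch.randn(50, 99, dtype=torch.float64)

# Dense baseline: materialize every 100x100 matrix and compute all of its eigenvalues
dense = (torch.diag_embed(diags64)
         + torch.diag_embed(off_diags64, offset=1)
         + torch.diag_embed(off_diags64, offset=-1))
dense_k = torch.linalg.eigvalsh(dense)[:, k]

# Tridiagonal solver: only the k-th eigenvalue of each matrix, one matrix at a time
tri_k = torch.tensor(
    [eigh_tridiagonal(d.numpy(), e.numpy(), eigvals_only=True,
                      select='i', select_range=(k, k))[0]
     for d, e in zip(diags64, off_diags64)],
    dtype=torch.float64,
)

print(torch.allclose(dense_k, tri_k, atol=1e-8))
```

<p>The dense solver computes all 100 eigenvalues of every matrix and then discards all but one, whereas the tridiagonal solver is asked for the single eigenvalue we need.</p>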

<p>Now let’s apply our PyTorch function to the raw tensors. Note that they do not require gradients, since they aren’t trainable parameters, so we’re going through the <code class="language-plaintext highlighter-rouge">no_grad</code> path:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span> <span class="o">-</span><span class="n">r</span> <span class="mi">30</span> <span class="o">-</span><span class="n">n</span> <span class="mi">100</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">TridiagEigvalsh</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">diags</span><span class="p">,</span> <span class="n">off_diags</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>3.71 ms ± 91.7 µs per loop (mean ± std. dev. of 30 runs, 100 loops each)
</code></pre></div></div>
<p>Whoa! That’s fast! Now let’s time the forward and backward passes with trainable parameters:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">diags_param</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">diags</span><span class="p">)</span>
<span class="n">off_diags_param</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">off_diags</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span> <span class="o">-</span><span class="n">r</span> <span class="mi">30</span> <span class="o">-</span><span class="n">n</span> <span class="mi">100</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">TridiagEigvalsh</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">diags_param</span><span class="p">,</span> <span class="n">off_diags_param</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
<span class="n">w</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>4.88 ms ± 205 µs per loop (mean ± std. dev. of 30 runs, 100 loops each)
</code></pre></div></div>
<p>Pretty fast: computing the eigenvalues and their gradients for a mini-batch of 50 tridiagonal matrices of size 100x100 takes approximately 5 milliseconds, compared with approximately 35 milliseconds for full dense matrices. Quite a speedup! For convenience, let’s wrap our autograd function class with a simple Python function:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">tridiag_eigvalsh</span><span class="p">(</span>
        <span class="n">diag</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">off_diag</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">:</span>
    <span class="k">return</span> <span class="n">TridiagEigvalsh</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">diag</span><span class="p">,</span> <span class="n">off_diag</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span>
</code></pre></div></div>

<p>So now, to train a model we need a torch module representing our \(f(\boldsymbol x; \boldsymbol A_{0..n})\) for the tridiagonal case. This means our trainable parameters are the diagonals and the off-diagonals of the matrices \(\boldsymbol A_0, \dots, \boldsymbol A_n\). Note that both the diagonal vector and the off-diagonal vector of \(\mathcal{A}(\boldsymbol x)\) are affine functions of \(\boldsymbol x\): the bias terms come from \(\boldsymbol A_0\), and the weights come from \(\boldsymbol A_1, \dots, \boldsymbol A_n\). So we can express them as simple <code class="language-plaintext highlighter-rouge">torch.nn.Linear</code> layers. This yields an almost magically simple class:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">TridiagSpectral</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">num_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">eig_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eig_idx</span> <span class="o">=</span> <span class="n">eig_idx</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">diag</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">off_diag</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">tridiag_eigvalsh</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="bp">self</span><span class="p">.</span><span class="n">off_diag</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="bp">self</span><span class="p">.</span><span class="n">eig_idx</span><span class="p">)</span>
</code></pre></div></div>
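<p>To see why two <code class="language-plaintext highlighter-rouge">nn.Linear</code> layers are enough, we can assemble \(\boldsymbol A_0 + \sum_i x_i \boldsymbol A_i\) explicitly from their parameters. The diagonal of \(\boldsymbol A_0\) is the bias of the diagonal layer, and the diagonal of \(\boldsymbol A_i\) is the \(i\)-th column of its weight matrix. The off-diagonals work the same way. Then we can compare the result with the matrix built from the layer outputs. This is an illustrative sketch: the layers below are hypothetical stand-ins for the ones inside <code class="language-plaintext highlighter-rouge">TridiagSpectral</code>, and a dense eigensolver is used only for display:</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
num_features, dim = 4, 5

# Hypothetical stand-ins for the two layers inside TridiagSpectral
diag_layer = nn.Linear(num_features, dim)
off_layer = nn.Linear(num_features, dim - 1)

def tridiag_dense(d: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
    # Symmetric tridiagonal matrix with diagonal d and off-diagonals e
    return torch.diag_embed(d) + torch.diag_embed(e, offset=1) + torch.diag_embed(e, offset=-1)

with torch.no_grad():
    # A_0 is assembled from the biases, A_i from the i-th weight columns
    A0 = tridiag_dense(diag_layer.bias, off_layer.bias)
    As = [tridiag_dense(diag_layer.weight[:, i], off_layer.weight[:, i])
          for i in range(num_features)]

    x = torch.randn(num_features)
    A_sum = A0 + sum(x[i] * As[i] for i in range(num_features))  # A_0 + sum_i x_i A_i
    A_layers = tridiag_dense(diag_layer(x), off_layer(x))        # built from the layer outputs

print(torch.allclose(A_sum, A_layers, atol=1e-5))  # True
print(torch.linalg.eigvalsh(A_layers)[dim // 2])   # eigenvalue with index dim // 2 (ascending order)
```

<p>The module never materializes these matrices. It passes only the two vectors to the tridiagonal solver, so there are \((2\cdot\text{dim} - 1)(n + 1)\) parameters instead of \(\text{dim}^2 (n + 1)\).</p>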

<p>Now we can use it for training, like any PyTorch model. So let’s try learning a classifier that detects whether a binary vector contains exactly two or exactly five ones:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">toy_function</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span>
        <span class="n">x</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span>
        <span class="n">x</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="mi">5</span>
    <span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
</code></pre></div></div>

<p>We shall learn this function over 12-dimensional binary vectors. So let’s generate all such vectors and compute their true labels:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_features</span> <span class="o">=</span> <span class="mi">12</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cartesian_prod</span><span class="p">(</span><span class="o">*</span><span class="p">([</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">])]</span> <span class="o">*</span> <span class="n">n_features</span><span class="p">))</span> 
<span class="n">y</span> <span class="o">=</span> <span class="n">toy_function</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
</code></pre></div></div>
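<p>A quick count tells us how balanced the labels are. Exactly \(\binom{12}{2} + \binom{12}{5} = 66 + 792 = 858\) of the 4096 vectors are positive, roughly 21%. The sketch below recomputes the labels under the separate names <code class="language-plaintext highlighter-rouge">X_check</code> and <code class="language-plaintext highlighter-rouge">y_check</code>. It uses a logical or (<code class="language-plaintext highlighter-rouge">|</code>), which gives the same result as <code class="language-plaintext highlighter-rouge">torch.maximum</code> on boolean tensors:</p>

```python
import math
import torch

n_features = 12
X_check = torch.cartesian_prod(*([torch.tensor([0., 1.])] * n_features))
ones_count = X_check.sum(dim=-1)
y_check = ((ones_count == 2) | (ones_count == 5)).to(dtype=torch.float32)

print(X_check.shape)                        # torch.Size([4096, 12])
print(int(y_check.sum()))                   # 858
print(math.comb(12, 2) + math.comb(12, 5))  # 858
```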

<p>This set should contain \(2^{12} = 4096\) vectors. Before training, let’s split the features and labels into a training set and an evaluation set:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">train_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">X</span><span class="p">))</span> <span class="o">&lt;</span> <span class="mf">0.5</span>
<span class="n">X_train</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">train_mask</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">y_train</span> <span class="o">=</span> <span class="n">y</span><span class="p">[</span><span class="n">train_mask</span><span class="p">]</span>
<span class="n">X_test</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="o">~</span><span class="n">train_mask</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">y_test</span> <span class="o">=</span> <span class="n">y</span><span class="p">[</span><span class="o">~</span><span class="n">train_mask</span><span class="p">]</span>
</code></pre></div></div>

<p>Alright! We’re ready to train a classifier on <code class="language-plaintext highlighter-rouge">(X_train, y_train)</code> and evaluate it on <code class="language-plaintext highlighter-rouge">(X_test, y_test)</code>. This is a good time to introduce the <a href="https://github.com/alexshtf/fitstream/">fitstream</a> library, which is handy for training PyTorch models on small in-memory datasets. Recall that we found it very convenient to hide the training loop behind a <em>Python generator</em> that yields an event after every epoch. This is exactly what the library does: it runs a fairly standard PyTorch training loop and yields a dict with some data at the end of each epoch. Let’s first install it in our notebook:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="n">fitstream</span>
</code></pre></div></div>

<p>Now let’s use it. Below is a short snippet demonstrating how we iterate over the first 3 events, which are simple Python dicts, and use Python’s <a href="https://docs.python.org/3/library/pprint.html">pprint</a> library to nicely print each dict:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">fitstream</span> <span class="k">as</span> <span class="n">fts</span>
<span class="kn">from</span> <span class="nn">pprint</span> <span class="kn">import</span> <span class="n">pprint</span>

<span class="c1"># define model and optimizer
</span><span class="n">dim</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">TridiagSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">n_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">eig_idx</span><span class="o">=</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-1</span><span class="p">)</span>

<span class="c1"># use FitStream to obtain the event generator
</span><span class="n">events</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">epoch_stream</span><span class="p">(</span>
    <span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="n">model</span><span class="p">,</span> <span class="n">optim</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span>
<span class="p">)</span>

<span class="c1"># iterate over the first three events
</span><span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">event</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span> <span class="n">events</span><span class="p">):</span>
    <span class="n">pprint</span><span class="p">(</span><span class="n">event</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">'---'</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>{'model': TridiagSpectral(
  (diag): Linear(in_features=12, out_features=3, bias=True)
  (off_diag): Linear(in_features=12, out_features=2, bias=True)
),
 'step': 1,
 'train_loss': 0.47607117891311646,
 'train_time_sec': 0.19717628799844533}
---
{'model': TridiagSpectral(
  (diag): Linear(in_features=12, out_features=3, bias=True)
  (off_diag): Linear(in_features=12, out_features=2, bias=True)
),
 'step': 2,
 'train_loss': 0.4695621430873871,
 'train_time_sec': 0.17527212999993935}
---
{'model': TridiagSpectral(
  (diag): Linear(in_features=12, out_features=3, bias=True)
  (off_diag): Linear(in_features=12, out_features=2, bias=True)
),
 'step': 3,
 'train_loss': 0.46067821979522705,
 'train_time_sec': 0.1757020229997579}
---
</code></pre></div></div>
<p>Each dict contains our model, the epoch index in the <code class="language-plaintext highlighter-rouge">step</code> key, the training loss, and the training time in seconds. This is pretty minimal, so the library comes with helper functions that take this minimal event stream and enrich it.</p>
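<p>To make the generator idea concrete, here is a minimal hand-written epoch generator that yields dicts with the same four keys. It is a hypothetical sketch, not fitstream’s actual implementation. In particular, averaging the per-batch losses weighted by batch size is our own assumption. The demo variables have distinct names so they do not clobber the notebook’s <code class="language-plaintext highlighter-rouge">X</code>, <code class="language-plaintext highlighter-rouge">y</code>, or <code class="language-plaintext highlighter-rouge">model</code>:</p>

```python
import time
import torch
from torch import nn

def simple_epoch_stream(data, model, optim, loss_fn, batch_size):
    """Hypothetical minimal epoch generator (a sketch, not fitstream's code)."""
    X, y = data
    step = 0
    while True:  # unbounded: the consumer decides when to stop
        step += 1
        start = time.perf_counter()
        model.train()
        perm = torch.randperm(len(X))  # reshuffle every epoch
        total_loss = 0.0
        for i in range(0, len(X), batch_size):
            idx = perm[i:i + batch_size]
            optim.zero_grad()
            loss = loss_fn(model(X[idx]), y[idx])
            loss.backward()
            optim.step()
            total_loss += loss.item() * len(idx)  # batch-size-weighted sum
        yield {"model": model, "step": step,
               "train_loss": total_loss / len(X),
               "train_time_sec": time.perf_counter() - start}

# Usage on a tiny synthetic problem
torch.manual_seed(0)
X_demo = torch.randn(256, 3)
y_demo = (X_demo[:, 0] > 0).float()
demo_model = nn.Sequential(nn.Linear(3, 1), nn.Flatten(0))  # outputs shape (batch,)
demo_optim = torch.optim.Adam(demo_model.parameters(), lr=1e-1)
demo_events = simple_epoch_stream((X_demo, y_demo), demo_model, demo_optim,
                                  nn.BCEWithLogitsLoss(), batch_size=32)
log = [(event["step"], event["train_loss"]) for _, event in zip(range(3), demo_events)]
print(log)
```

<p>Because the generator is lazy, an epoch runs only when the consumer asks for the next event. Here <code class="language-plaintext highlighter-rouge">zip</code> stops after three events, so exactly three epochs are trained.</p>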

<p>The library provides a <code class="language-plaintext highlighter-rouge">pipe</code> function, which passes the event stream through a sequence of transformations. The first transformation we introduce is <code class="language-plaintext highlighter-rouge">take</code>, which keeps only the specified number of events from the head of the stream. For example, this produces 5 events:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># define model and optimizer
</span><span class="n">dim</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">TridiagSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">n_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">eig_idx</span><span class="o">=</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-1</span><span class="p">)</span>

<span class="c1"># pipe the epoch stream through the "take" transformation
</span><span class="n">events</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">pipe</span><span class="p">(</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">epoch_stream</span><span class="p">(</span>
        <span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="n">model</span><span class="p">,</span> <span class="n">optim</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span>
    <span class="p">),</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">take</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="p">)</span>

<span class="k">for</span> <span class="n">event</span> <span class="ow">in</span> <span class="n">events</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="n">event</span><span class="p">[</span><span class="s">'step'</span><span class="p">],</span> <span class="n">event</span><span class="p">[</span><span class="s">'train_loss'</span><span class="p">])</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>1 0.4920884072780609
2 0.46535375714302063
3 0.4354994297027588
4 0.4041599929332733
5 0.32225537300109863
</code></pre></div></div>
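<p>Conceptually, a transformation is just a function from one event iterator to another, and piping is function composition. The sketch below is a hypothetical re-implementation of that idea built on <code class="language-plaintext highlighter-rouge">itertools.islice</code>. It is not fitstream’s source, and <code class="language-plaintext highlighter-rouge">counting_stream</code> is a made-up stand-in for an epoch stream:</p>

```python
import itertools

def pipe(stream, *transforms):
    """Hypothetical sketch: feed the stream through each transformation in order."""
    for transform in transforms:
        stream = transform(stream)
    return stream

def take(n):
    """Hypothetical sketch: a transformation that lazily keeps the first n events."""
    return lambda events: itertools.islice(events, n)

def counting_stream():
    # Unbounded stand-in for an epoch stream; each event is a dict
    step = 0
    while True:
        step += 1
        yield {"step": step, "train_loss": 1.0 / step}

steps = [event["step"] for event in pipe(counting_stream(), take(5))]
print(steps)  # [1, 2, 3, 4, 5]
```

<p>Laziness matters here. The epoch stream presumably never ends on its own, and <code class="language-plaintext highlighter-rouge">islice</code> stops pulling after the fifth event, so no extra epochs are trained.</p>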

<p>The second important transformation is <code class="language-plaintext highlighter-rouge">augment</code>, which adds new keys to each event. Here is an example of adding a key with the squared training loss:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># define model and optimizer
</span><span class="n">dim</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">TridiagSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">n_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">eig_idx</span><span class="o">=</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-1</span><span class="p">)</span>

<span class="c1"># pipe the epoch stream through the "augment" and "take" transformations
</span><span class="n">events</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">pipe</span><span class="p">(</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">epoch_stream</span><span class="p">(</span>
        <span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="n">model</span><span class="p">,</span> <span class="n">optim</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span>
    <span class="p">),</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">augment</span><span class="p">(</span><span class="k">lambda</span> <span class="n">event</span><span class="p">:</span> <span class="p">{</span><span class="s">"loss_squared"</span><span class="p">:</span> <span class="n">event</span><span class="p">[</span><span class="s">"train_loss"</span><span class="p">]</span> <span class="o">**</span> <span class="mi">2</span><span class="p">}),</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">take</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="p">)</span>

<span class="k">for</span> <span class="n">event</span> <span class="ow">in</span> <span class="n">events</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="n">event</span><span class="p">[</span><span class="s">'step'</span><span class="p">],</span> <span class="n">event</span><span class="p">[</span><span class="s">"train_loss"</span><span class="p">],</span> <span class="n">event</span><span class="p">[</span><span class="s">'loss_squared'</span><span class="p">])</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>1 0.4896900951862335 0.23979638932350245
2 0.45949652791023254 0.21113705916155912
3 0.46307340264320374 0.2144369762355547
4 0.44300273060798645 0.19625141932613221
5 0.4168809950351715 0.1737897640215147
</code></pre></div></div>
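<p>In the same spirit, here is a hypothetical sketch of what an <code class="language-plaintext highlighter-rouge">augment</code>-style transformation does. It merges the dict returned by a user function into each event. This illustrates the idea and is not the library’s code:</p>

```python
def augment(fn):
    """Hypothetical sketch: merge the keys returned by fn(event) into every event."""
    return lambda events: ({**event, **fn(event)} for event in events)

raw_events = [{"step": 1, "train_loss": 0.5}, {"step": 2, "train_loss": 0.25}]
add_square = augment(lambda event: {"loss_squared": event["train_loss"] ** 2})
augmented = list(add_square(raw_events))
print(augmented[1])  # {'step': 2, 'train_loss': 0.25, 'loss_squared': 0.0625}
```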

<p>This is all nice, but the library offers more. It comes with transformations for adding a validation loss, for early stopping, and for simply executing a function on every event (or every few events). So here is an event generator for a full-fledged training procedure with a validation loss and early stopping, which also prints the losses every 10 epochs:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dim</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">TridiagSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">n_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">eig_idx</span><span class="o">=</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-1</span><span class="p">)</span>
<span class="n">events</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">pipe</span><span class="p">(</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">epoch_stream</span><span class="p">(</span>
        <span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="n">model</span><span class="p">,</span> <span class="n">optim</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span>
    <span class="p">),</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">augment</span><span class="p">(</span><span class="n">fts</span><span class="p">.</span><span class="n">validation_loss</span><span class="p">((</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">),</span> <span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">())),</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">early_stop</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="s">"val_loss"</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">),</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">tap</span><span class="p">(</span><span class="n">fts</span><span class="p">.</span><span class="n">print_keys</span><span class="p">(</span><span class="s">"train_loss"</span><span class="p">,</span> <span class="s">"val_loss"</span><span class="p">),</span> <span class="n">every</span><span class="o">=</span><span class="mi">10</span><span class="p">),</span>
    <span class="n">fts</span><span class="p">.</span><span class="n">take</span><span class="p">(</span><span class="mi">100</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div></div>
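<p>Early stopping also fits naturally into the generator picture: a transformation forwards events until the monitored key stops improving. Below is a hypothetical sketch with patience semantics that we chose ourselves. The event that exhausts the patience is still yielded, and the best model is not restored. fitstream’s <code class="language-plaintext highlighter-rouge">early_stop</code> may differ in these details:</p>

```python
def early_stop(key, patience):
    """Hypothetical sketch: stop once `key` has not improved for `patience` consecutive events."""
    def transform(events):
        best = float("inf")
        since_best = 0
        for event in events:
            yield event
            if event[key] >= best:  # no improvement
                since_best += 1
                if since_best >= patience:
                    return
            else:
                best, since_best = event[key], 0
    return transform

val_losses = [0.9, 0.5, 0.6, 0.55, 0.7, 0.1]
events_in = ({"step": i + 1, "val_loss": v} for i, v in enumerate(val_losses))
kept = [event["step"] for event in early_stop("val_loss", patience=2)(events_in)]
print(kept)  # [1, 2, 3, 4]
```

<p>The example also shows the usual trade-off of early stopping. The stream ends at step 4, so the later improvement to 0.1 is never seen.</p>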
<p>Finally, the library comes with a set of <em>collector</em> functions that iterate over the events and collect them into various data structures. Here it will be convenient to use <code class="language-plaintext highlighter-rouge">collect_pd</code>, which collects the event dicts into a Pandas DataFrame. So here is an example of collecting the above event stream into a data frame, and then plotting the training and validation losses:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">training_log</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">collect_pd</span><span class="p">(</span><span class="n">events</span><span class="p">)</span>
<span class="n">training_log</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="s">"step"</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="p">[</span><span class="s">"train_loss"</span><span class="p">,</span> <span class="s">"val_loss"</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">'Dim = 3'</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>step=0001 train_loss=0.4908 val_loss=0.4681
step=0011 train_loss=0.3369 val_loss=0.3147
step=0021 train_loss=0.1540 val_loss=0.1516
step=0031 train_loss=0.0768 val_loss=0.0926
step=0041 train_loss=0.0500 val_loss=0.0725
step=0051 train_loss=0.0363 val_loss=0.0621
step=0061 train_loss=0.0291 val_loss=0.0580
step=0071 train_loss=0.0241 val_loss=0.0557
step=0081 train_loss=0.0210 val_loss=0.0538
step=0091 train_loss=0.0186 val_loss=0.0536
step=0101 train_loss=0.0172 val_loss=0.0539
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_tridiag_toy_3.png" alt="pow_spec_tridiag_toy_3" /></p>

<p>Nice! The model is learning, which is encouraging. Before we do more experiments, let’s write a function that constructs such a training procedure and collects the events into a DataFrame:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">run_experiment_bce</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-1</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">max_epochs</span><span class="o">=</span><span class="mi">100</span><span class="p">):</span>
    <span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span>
    <span class="n">events</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">pipe</span><span class="p">(</span>
        <span class="n">fts</span><span class="p">.</span><span class="n">epoch_stream</span><span class="p">(</span>
            <span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="n">model</span><span class="p">,</span> <span class="n">optim</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span>
        <span class="p">),</span>
        <span class="n">fts</span><span class="p">.</span><span class="n">augment</span><span class="p">(</span><span class="n">fts</span><span class="p">.</span><span class="n">validation_loss</span><span class="p">((</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">),</span> <span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">())),</span>
        <span class="n">fts</span><span class="p">.</span><span class="n">early_stop</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="s">"val_loss"</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">),</span>
        <span class="n">fts</span><span class="p">.</span><span class="n">take</span><span class="p">(</span><span class="n">max_epochs</span><span class="p">)</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="n">fts</span><span class="p">.</span><span class="n">collect_pd</span><span class="p">(</span><span class="n">events</span><span class="p">)</span>
</code></pre></div></div>
<p>Now we can easily plot similar losses for \(5 \times 5\) tridiagonal matrices:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">training_log_5</span> <span class="o">=</span> <span class="n">run_experiment_bce</span><span class="p">(</span>
    <span class="n">TridiagSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">n_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">eig_idx</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">training_log_5</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="s">"step"</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="p">[</span><span class="s">"train_loss"</span><span class="p">,</span> <span class="s">"val_loss"</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">'Dim = 5'</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_tridiag_toy_5.png" alt="pow_spec_tridiag_toy_5" /></p>

<p>Much better! The validation loss is now very close to zero as well. Next, let’s move to \(9 \times 9\) matrices:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">training_log_9</span> <span class="o">=</span> <span class="n">run_experiment_bce</span><span class="p">(</span>
    <span class="n">TridiagSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">n_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span> <span class="n">eig_idx</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">training_log_9</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="s">"step"</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="p">[</span><span class="s">"train_loss"</span><span class="p">,</span> <span class="s">"val_loss"</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">'Dim = 9'</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="https://alexshtf.github.io/assets/pow_spec_tridiag_toy_9.png" alt="pow_spec_tridiag_toy_9" /></p>

<p>Beautiful! A model with \(9 \times 9\) symmetric tridiagonal matrices, which has \(13 \times (9 + 8) = 221\) parameters (9 diagonal and 8 off-diagonal entries per matrix), learns this function from data almost perfectly. Conceptually, this is just a linear function of the features followed by a non-linear function: the matrix eigenvalue. Just one neuron! You can try it yourself, but a “classical” neuron cannot learn this function.</p>
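<p>To make the “linear function followed by an eigenvalue” view concrete, here is a minimal NumPy/SciPy sketch of inference for one such neuron. It is not the notebook’s <code class="language-plaintext highlighter-rouge">TridiagSpectral</code>; it assumes each matrix \({\boldsymbol A}_i\) is stored as a diagonal vector and an off-diagonal vector, that <code class="language-plaintext highlighter-rouge">eig_idx</code> is 0-based, and that the toy problem has 12 features plus \({\boldsymbol A}_0\), which matches the 13 matrices in the count above. The names <code class="language-plaintext highlighter-rouge">tridiag_neuron</code>, <code class="language-plaintext highlighter-rouge">diag_weights</code>, and <code class="language-plaintext highlighter-rouge">offdiag_weights</code> are hypothetical.</p>

```python
import numpy as np
from scipy.linalg import eigvalsh_tridiagonal


def tridiag_neuron(x, diag_weights, offdiag_weights, eig_idx):
    """lambda_k(A_0 + sum_i x_i A_i) for symmetric tridiagonal A_0, ..., A_n.

    diag_weights:    shape (n + 1, dim); row i is the main diagonal of A_i.
    offdiag_weights: shape (n + 1, dim - 1); row i is the off-diagonal of A_i.
    eig_idx:         0-based index, counting from the smallest eigenvalue.
    """
    coeffs = np.concatenate([[1.0], x])   # the constant 1 multiplies A_0
    diag = coeffs @ diag_weights          # main diagonal of the combined matrix
    offdiag = coeffs @ offdiag_weights    # off-diagonal of the combined matrix
    # Request only the eig_idx-th eigenvalue instead of the whole spectrum
    return eigvalsh_tridiagonal(diag, offdiag, select="i",
                                select_range=(eig_idx, eig_idx))[0]


n_features, dim = 12, 9   # assumption: 12 features plus A_0 gives 13 matrices
rng = np.random.default_rng(0)
diag_w = rng.standard_normal((n_features + 1, dim))
offdiag_w = rng.standard_normal((n_features + 1, dim - 1))
num_params = diag_w.size + offdiag_w.size   # 13 * (9 + 8)

x = rng.standard_normal(n_features)
value = tridiag_neuron(x, diag_w, offdiag_w, eig_idx=4)

# Cross-check against a dense symmetric eigensolver
coeffs = np.concatenate([[1.0], x])
dense = (np.diag(coeffs @ diag_w) + np.diag(coeffs @ offdiag_w, 1)
         + np.diag(coeffs @ offdiag_w, -1))
print(num_params, np.isclose(value, np.linalg.eigvalsh(dense)[4]))  # 221 True
```

<p>Since a linear combination of tridiagonal matrices is itself tridiagonal, the combined matrix never has to be formed densely, and SciPy can compute just the one eigenvalue the model needs.</p>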

<p>Now that we’re convinced the machinery works, let’s try it on the dataset that accompanies this series: the California Housing dataset, which is built into Colab.</p>

<h1 id="california-housing-training">California housing training</h1>

<p>Recall that the dataset is about predicting housing prices in California based on some features. I will skip the part where we read the data, normalize features and targets, and split the data into training and test sets. We’ve already done it in previous posts in this series, and the notebook contains the full code. So here we’ll assume our training data is in <code class="language-plaintext highlighter-rouge">X_train, y_train</code>, our evaluation set is <code class="language-plaintext highlighter-rouge">X_test, y_test</code>, and the number of features is in <code class="language-plaintext highlighter-rouge">num_features</code>. Moreover, since our labels are scaled, we also have <code class="language-plaintext highlighter-rouge">label_scale</code>, which is the factor that transforms the training / eval RMSE back to the original units in the dataset - dollars.</p>

<p>First, let’s define a simple function that computes the RMSE in dollars:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">scaled_rmse</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
    <span class="n">mse</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="n">mse_loss</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">mse</span><span class="p">)</span> <span class="o">*</span> <span class="n">label_scale</span>
</code></pre></div></div>
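<p>Multiplying the RMSE by <code class="language-plaintext highlighter-rouge">label_scale</code> recovers dollars because RMSE depends only on the prediction errors. Assuming the labels were standardized as \((y - \mu) / s\), with \(s\) playing the role of <code class="language-plaintext highlighter-rouge">label_scale</code> (the exact preprocessing lives in the notebook), the mean cancels in every error and the RMSE shrinks by exactly \(s\). A quick NumPy check on synthetic data:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
y_dollars = rng.uniform(50_000, 500_000, size=1000)           # synthetic prices
pred_dollars = y_dollars + rng.normal(0, 60_000, size=1000)   # synthetic predictions

# Assumed standardization: subtract a mean, divide by label_scale
label_mean, label_scale = y_dollars.mean(), y_dollars.std()
y_scaled = (y_dollars - label_mean) / label_scale
pred_scaled = (pred_dollars - label_mean) / label_scale


def rmse(y_true, y_pred):
    return np.sqrt(np.mean((y_pred - y_true) ** 2))


rmse_dollars = rmse(y_dollars, pred_dollars)
rmse_via_scale = rmse(y_scaled, pred_scaled) * label_scale   # what scaled_rmse computes
print(np.isclose(rmse_dollars, rmse_via_scale))  # True
```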

<p>Now we can define a full-fledged training procedure with the FitStream library we just introduced. When experimenting, I noticed that learning rate scheduling substantially improves convergence and lets me work with fewer epochs, so I also used a learning rate scheduler with warmup, just like we do with LLMs. It first increases the learning rate for a few epochs (warmup), and then slowly decreases it towards zero (cooldown). It is implemented in the <code class="language-plaintext highlighter-rouge">OneCycleLR</code> class from PyTorch. Since we step the scheduler once per epoch, its <code class="language-plaintext highlighter-rouge">total_steps</code> is the number of epochs. So here is our full training procedure:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch.optim.lr_scheduler</span> <span class="kn">import</span> <span class="n">OneCycleLR</span>

<span class="k">def</span> <span class="nf">complete_training_stream</span><span class="p">(</span>
        <span class="n">dim</span><span class="p">,</span> <span class="n">n_epochs</span><span class="p">,</span> <span class="n">warmup_fraction</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">5e-3</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
    <span class="p">):</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">TridiagSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">eig_idx</span><span class="o">=</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
    <span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span>
    <span class="n">sched</span> <span class="o">=</span> <span class="n">OneCycleLR</span><span class="p">(</span>
        <span class="n">optim</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">total_steps</span><span class="o">=</span><span class="n">n_epochs</span><span class="p">,</span> <span class="n">pct_start</span><span class="o">=</span><span class="n">warmup_fraction</span><span class="p">,</span> <span class="n">anneal_strategy</span><span class="o">=</span><span class="s">'linear'</span>
    <span class="p">)</span>

    <span class="n">epoch_events</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">epoch_stream</span><span class="p">((</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="n">model</span><span class="p">,</span> <span class="n">optim</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">(),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">fts</span><span class="p">.</span><span class="n">pipe</span><span class="p">(</span>
        <span class="n">epoch_events</span><span class="p">,</span>
        <span class="n">fts</span><span class="p">.</span><span class="n">take</span><span class="p">(</span><span class="n">n_epochs</span><span class="p">),</span>
        <span class="n">fts</span><span class="p">.</span><span class="n">augment</span><span class="p">(</span><span class="n">fts</span><span class="p">.</span><span class="n">validation_loss</span><span class="p">((</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">),</span> <span class="n">scaled_rmse</span><span class="p">)),</span> <span class="c1"># &lt;-- here we use scaled_rmse
</span>        <span class="n">fts</span><span class="p">.</span><span class="n">augment</span><span class="p">(</span><span class="k">lambda</span> <span class="n">event</span><span class="p">:</span> <span class="p">{</span><span class="s">"lr"</span><span class="p">:</span> <span class="n">optim</span><span class="p">.</span><span class="n">param_groups</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="s">'lr'</span><span class="p">]}),</span>
        <span class="n">fts</span><span class="p">.</span><span class="n">early_stop</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="s">"val_loss"</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="n">n_epochs</span> <span class="o">//</span> <span class="mi">10</span><span class="p">),</span>
        <span class="n">fts</span><span class="p">.</span><span class="n">tick</span><span class="p">(</span><span class="n">sched</span><span class="p">.</span><span class="n">step</span><span class="p">),</span>
    <span class="p">)</span>
</code></pre></div></div>
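<p>Note where we use our <code class="language-plaintext highlighter-rouge">scaled_rmse</code>: it is the loss passed to <code class="language-plaintext highlighter-rouge">fts.validation_loss</code>, so the <code class="language-plaintext highlighter-rouge">val_loss</code> column of the stream is in dollars. Now, let’s try it out with 11-dimensional matrices for 20 epochs:</p>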
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">training_log</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">collect_pd</span><span class="p">(</span><span class="n">complete_training_stream</span><span class="p">(</span><span class="mi">11</span><span class="p">,</span> <span class="mi">20</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">training_log</span><span class="p">)</span>
</code></pre></div></div>
<p>This is what I got:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>step  train_loss  train_time_sec       val_loss        lr
0      1    0.893857        1.095248  101073.101562  0.000200
1      2    0.404809        1.114020   65408.109375  0.005000
2      3    0.298268        1.138942   62557.234375  0.004722
3      4    0.277326        1.107545   61244.425781  0.004444
4      5    0.268057        1.099500   60413.164062  0.004167
5      6    0.262378        1.105777   59918.421875  0.003889
6      7    0.256854        1.100250   59919.507812  0.003611
7      8    0.253308        1.095398   59363.906250  0.003333
8      9    0.250837        1.091576   59104.089844  0.003056
9     10    0.248961        1.083466   59065.980469  0.002778
10    11    0.246233        1.092109   58896.339844  0.002500
11    12    0.244495        1.114625   58766.113281  0.002222
12    13    0.241918        1.118336   58610.593750  0.001944
13    14    0.240912        1.101395   58488.941406  0.001667
14    15    0.239700        1.100444   58310.894531  0.001389
15    16    0.238621        1.096240   58618.683594  0.001111
16    17    0.237742        1.085707   58579.175781  0.000833
</code></pre></div></div>
<p>The model is training, and the learning rate increases over the first two epochs, as expected, since 10% of the 20 epochs are warmup. Training stopped after 17 epochs due to the early stopping mechanism, whose patience is two epochs (again, 10% of the maximum): the best validation loss was reached at step 15, and steps 16 and 17 did not improve on it.</p>
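<p>To see where the <code class="language-plaintext highlighter-rouge">lr</code> column and the stopping point come from, here is a plain-Python re-derivation. It assumes <code class="language-plaintext highlighter-rouge">OneCycleLR</code> uses its default <code class="language-plaintext highlighter-rouge">div_factor=25</code> and <code class="language-plaintext highlighter-rouge">final_div_factor=1e4</code> with a two-phase linear schedule, and that FitStream’s <code class="language-plaintext highlighter-rouge">early_stop</code> counts consecutive epochs without a strict improvement in <code class="language-plaintext highlighter-rouge">val_loss</code>. Both are assumptions, not the libraries’ actual code, and the helper names are hypothetical. Under them, the sketch reproduces the logged learning rates and the stop at step 17.</p>

```python
def one_cycle_linear_lr(step, total_steps, max_lr, pct_start,
                        div_factor=25.0, final_div_factor=1e4):
    """Assumed two-phase linear one-cycle schedule (0-based scheduler step)."""
    initial_lr = max_lr / div_factor          # 5e-3 / 25 = 2e-4
    min_lr = initial_lr / final_div_factor    # 2e-8
    warmup_end = pct_start * total_steps - 1  # last step of the warmup phase
    if warmup_end >= step:
        return initial_lr + (max_lr - initial_lr) * step / warmup_end
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return max_lr + (min_lr - max_lr) * pct


def early_stop_step(val_losses, patience):
    """Assumed rule: the 1-based step at which patience runs out, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for step, loss in enumerate(val_losses, start=1):
        if loss >= best:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return step
        else:
            best, epochs_without_improvement = loss, 0
    return None


# One learning rate per logged epoch; the scheduler ticks after each epoch
lrs = [one_cycle_linear_lr(s, total_steps=20, max_lr=5e-3, pct_start=0.1)
       for s in range(17)]
print([round(lr, 6) for lr in lrs[:4]])

# The val_loss column of the 20-epoch log above
val_losses = [101073.10, 65408.11, 62557.23, 61244.43, 60413.16, 59918.42,
              59919.51, 59363.91, 59104.09, 59065.98, 58896.34, 58766.11,
              58610.59, 58488.94, 58310.89, 58618.68, 58579.18]
print(early_stop_step(val_losses, patience=20 // 10))  # 17
```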

<p>We can also write a small function that plots the validation loss and the learning rate together. It’s a bit of boilerplate, since it uses the primary y-axis for the validation loss and a secondary y-axis for the learning rate.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">plot_log</span><span class="p">(</span><span class="n">log</span><span class="p">,</span> <span class="n">title</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>

    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">log</span><span class="p">.</span><span class="n">step</span><span class="p">,</span> <span class="n">log</span><span class="p">.</span><span class="n">val_loss</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'blue'</span><span class="p">,</span>
            <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">'RMSE (best=$</span><span class="si">{</span><span class="n">log</span><span class="p">.</span><span class="n">val_loss</span><span class="p">.</span><span class="nb">min</span><span class="p">()</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">)'</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"Error ($)"</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">grid</span><span class="p">()</span>

    <span class="n">lr_ax</span> <span class="o">=</span> <span class="n">ax</span><span class="p">.</span><span class="n">twinx</span><span class="p">()</span>
    <span class="n">lr_ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">log</span><span class="p">.</span><span class="n">step</span><span class="p">,</span> <span class="n">log</span><span class="p">.</span><span class="n">lr</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'Learning rate'</span><span class="p">,</span>
               <span class="n">color</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'dotted'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">lr_ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"Learning rate"</span><span class="p">)</span>

    <span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">title</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="n">title</span><span class="p">)</span>
    <span class="n">fig</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>

<span class="n">plot_log</span><span class="p">(</span><span class="n">training_log</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="https://alexshtf.github.io/assets/pow_spectrum_tri_calhousing_11_20.png" alt="pow_spectrum_tri_calhousing_11_20" /></p>

<p>In blue we see the validation loss, and in dotted gray we see the learning rate. The warmup and cooldown stages are clearly visible.</p>

<p>Now that we have all the machinery in place, let’s train some models with more epochs. I used 500 epochs in all the experiments, which was enough for both smaller and larger models. Let’s start with 7-dimensional tridiagonal matrices:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">training_log_7</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">collect_pd</span><span class="p">(</span><span class="n">complete_training_stream</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">500</span><span class="p">))</span>
<span class="n">plot_log</span><span class="p">(</span><span class="n">training_log_7</span><span class="p">,</span> <span class="n">title</span><span class="o">=</span><span class="s">'Dim=7'</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="https://alexshtf.github.io/assets/pow_spectrum_tri_calhousing_7_300.png" alt="pow_spectrum_tri_calhousing_7_300" /></p>

<p>How about 11-dimensional tridiagonals?</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">training_log_11</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">collect_pd</span><span class="p">(</span><span class="n">complete_training_stream</span><span class="p">(</span><span class="mi">11</span><span class="p">,</span> <span class="mi">500</span><span class="p">))</span>
<span class="n">plot_log</span><span class="p">(</span><span class="n">training_log_11</span><span class="p">,</span> <span class="n">title</span><span class="o">=</span><span class="s">'Dim=11'</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="https://alexshtf.github.io/assets/pow_spectrum_tri_calhousing_11_300.png" alt="pow_spectrum_tri_calhousing_11_300" /></p>

<p>Nice! Increasing the matrix size reduces the error, suggesting that performance improves as the model grows. But remember: it is just one neuron! How about \(15 \times 15\) matrices?</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">training_log_15</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">collect_pd</span><span class="p">(</span><span class="n">complete_training_stream</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">500</span><span class="p">))</span>
<span class="n">plot_log</span><span class="p">(</span><span class="n">training_log_15</span><span class="p">,</span> <span class="n">title</span><span class="o">=</span><span class="s">'Dim=15'</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="https://alexshtf.github.io/assets/pow_spectrum_tri_calhousing_15_300.png" alt="pow_spectrum_tri_calhousing_15_300" /></p>

<p>Another slight improvement. What about \(45 \times 45\) matrices?</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">training_log_45</span> <span class="o">=</span> <span class="n">fts</span><span class="p">.</span><span class="n">collect_pd</span><span class="p">(</span><span class="n">tqdm</span><span class="p">(</span><span class="n">complete_training_stream</span><span class="p">(</span><span class="mi">45</span><span class="p">,</span> <span class="mi">500</span><span class="p">)))</span>
<span class="n">plot_log</span><span class="p">(</span><span class="n">training_log_45</span><span class="p">,</span> <span class="n">title</span><span class="o">=</span><span class="s">'Dim=45'</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="https://alexshtf.github.io/assets/pow_spectrum_tri_calhousing_45_300.png" alt="pow_spectrum_tri_calhousing_45_300" /></p>

<p>Each of these experiments takes 3-4 minutes, as you can verify by running the notebook yourself. That is much faster than the dense matrix experiments from previous posts, and it needs no GPU. I suspect that if PyTorch had native support for tridiagonal eigenvalue problems, each experiment could run in seconds. Unfortunately, it does not.</p>

<p>Compare this with the dense experiments on the same dataset with similar matrix sizes in <a href="/2026/01/20/Spectrum-Speed.html">this post</a>: training with \(45 \times 45\) matrices took 31 minutes on an NVIDIA L4 GPU and reached a similar test error. Here we need no GPU, run roughly an order of magnitude faster (3-4 minutes versus 31), and get similar performance, at least on this dataset.</p>
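<p>For a rough feeling of where the savings come from, here is a small CPU micro-benchmark comparing a dense symmetric eigensolver with SciPy’s tridiagonal one on a single random \(45 \times 45\) tridiagonal matrix. It times one eigenvalue computation only, not the PyTorch training loop, and the numbers depend on your hardware and library builds, so treat it as a sketch rather than a reproduction of the timings above.</p>

```python
import timeit

import numpy as np
from scipy.linalg import eigvalsh_tridiagonal

dim = 45
k = dim // 2                         # 0-based index of a middle eigenvalue
rng = np.random.default_rng(0)
diag = rng.standard_normal(dim)
offdiag = rng.standard_normal(dim - 1)
dense = np.diag(diag) + np.diag(offdiag, 1) + np.diag(offdiag, -1)


def dense_eig():
    # Full dense spectrum, then pick the k-th eigenvalue
    return np.linalg.eigvalsh(dense)[k]


def tridiag_eig():
    # Only the k-th eigenvalue, straight from the two diagonals
    return eigvalsh_tridiagonal(diag, offdiag, select="i",
                                select_range=(k, k))[0]


assert np.isclose(dense_eig(), tridiag_eig())
for name, fn in [("dense", dense_eig), ("tridiagonal", tridiag_eig)]:
    seconds = timeit.timeit(fn, number=2000) / 2000
    print(f"{name:12s} {seconds * 1e6:8.1f} us per call")
```

<p>For matrices this small, per-call overhead may dominate, so the gap you observe can be smaller than the asymptotic difference between dense and tridiagonal eigensolvers suggests.</p>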

<p>Of course, these are not proper experiments I’d include in a paper. I haven’t conducted any hyperparameter search, a different optimizer might do better, and so on, but the point stands.</p>

<h1 id="summary">Summary</h1>

<p>To summarize, restricting ourselves to eigenvalue model families where all matrices are simultaneously tridiagonalizable can strike a good balance between speed and expressiveness. Recall why this model family is interesting: it’s just one neuron, a linear (matrix) function composed with a non-linearity, that is quite expressive while being fairly interpretable. These nice properties haven’t gone anywhere: the spectral norms of our tridiagonal matrices are still a reasonable measure of feature importance, and they certify how sensitive the model is to changes in the corresponding feature.</p>
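<p>A sketch of that certificate, under the same hypothetical storage as before (one diagonal and one off-diagonal vector per matrix): the spectral norm of a symmetric tridiagonal matrix is its largest absolute eigenvalue, and by Weyl’s inequality, changing feature \(x_i\) by \(t\) moves \(\lambda_k\) by at most \(|t| \cdot \|{\boldsymbol A}_i\|_2\). The code checks both on random matrices; all names and sizes are illustrative.</p>

```python
import numpy as np
from scipy.linalg import eigvalsh_tridiagonal


def tridiag_spectral_norm(diag, offdiag):
    """Spectral norm of a symmetric tridiagonal matrix: its largest |eigenvalue|."""
    w = eigvalsh_tridiagonal(diag, offdiag)     # eigenvalues in ascending order
    return max(abs(w[0]), abs(w[-1]))


def to_dense(diag, offdiag):
    return np.diag(diag) + np.diag(offdiag, 1) + np.diag(offdiag, -1)


rng = np.random.default_rng(1)
n_features, dim, k = 4, 7, 3                    # k is a 0-based eigenvalue index here
diag_w = rng.standard_normal((n_features + 1, dim))        # row 0 belongs to A_0
offdiag_w = rng.standard_normal((n_features + 1, dim - 1))

# Importance of feature i (0-based) is the spectral norm of A_{i+1}
importance = [tridiag_spectral_norm(diag_w[i + 1], offdiag_w[i + 1])
              for i in range(n_features)]


def f(x):
    """lambda_k(A_0 + sum_i x_i A_i), evaluated densely for the check."""
    c = np.concatenate([[1.0], x])
    return np.linalg.eigvalsh(to_dense(c @ diag_w, c @ offdiag_w))[k]


# Weyl's inequality: shifting x_i by t moves lambda_k by at most |t| * ||A_{i+1}||_2
x = rng.standard_normal(n_features)
i, t = 2, 0.7
x_shifted = x.copy()
x_shifted[i] += t
bound = abs(t) * importance[i]
print(np.round(importance, 3), bound + 1e-12 >= abs(f(x_shifted) - f(x)))
```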

<p>We do, however, see slow convergence. 500 epochs is quite a lot, and even though early stopping ends training sooner, it still takes a few hundred epochs. Even with best practices such as learning rate scheduling and early stopping, training is still quite slow. At this stage, this is the price we pay for a model that is, on the one hand, just one fairly interpretable neuron, and on the other hand can be improved by scaling.</p>

<p>We have many more questions to explore in this series. For example - can we prune any dense eigenvalue model to tridiagonal form? Can we make it converge faster? How do we stack such models as layers of a larger neural network? Stay tuned!</p>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="machine learning" /><category term="eigenvalue models" /><category term="spectral methods" /><category term="tridiagonal matrices" /><category term="structured matrices" /><category term="numerical linear algebra" /><category term="pytorch" /><category term="scipy" /><category term="autograd" /><summary type="html"><![CDATA[Cheaper eigenvalue training and inference with symmetric tridiagonal matrices: preserve useful expressiveness, use fast SciPy-backed PyTorch autograd, and avoid dense eigensolvers.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/pow_spectrum_tri_calhousing_45_300.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/pow_spectrum_tri_calhousing_45_300.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Interpreting eigenvalue models</title><link href="https://alexshtf.github.io/2026/02/03/Spectrum-Interpretation.html" rel="alternate" type="text/html" title="Interpreting eigenvalue models" /><published>2026-02-03T00:00:00+00:00</published><updated>2026-02-03T00:00:00+00:00</updated><id>https://alexshtf.github.io/2026/02/03/Spectrum-Interpretation</id><content type="html" xml:base="https://alexshtf.github.io/2026/02/03/Spectrum-Interpretation.html"><![CDATA[<h1 id="intro">Intro</h1>

<p>We continue our discussion of machine-learned models of the form</p>

\[f({\boldsymbol x};{\boldsymbol A}_{0:n}) = \lambda_k \Bigl({\boldsymbol A}_0 + \sum_{i=1}^n x_i {\boldsymbol A}_i\Bigr),\]

<p>where \({\boldsymbol A}_i\) are learned symmetric matrices, and \(\lambda_k\) is the \(k\)-th smallest eigenvalue. We touched on one aspect of interpretability in a previous post - the importance of each of the features \(x_1, \dots, x_n\). But there are other aspects: what does this model actually compute? How can we reason about it, or explain it to our colleagues or to a regulator? We began this series by saying that this is a “neuron” that is solving an optimization problem. So in this post we shall focus on the different kinds of optimization problems that \(f({\boldsymbol x}; {\boldsymbol A}_{0:n})\) solves, what interpretations we can give to them, and why we should care. All the results in this post are based on Chapter 4 of the <em>Matrix Analysis</em> book by Horn &amp; Johnson, 2nd edition.</p>

<h1 id="a-game-between-two-players">A game between two players</h1>

<p>You’ve probably seen eigenvalues presented as “stretch factors”. But for our understanding of the model, the optimization view is often more useful. So here is a Courant min-max characterization of the \(k\)-th smallest eigenvalue, named after Richard Courant. It’s a bit “hairy”, so let’s first present it, and then interpret it:</p>

\[\lambda_k({\boldsymbol A}) = \max_{ {\boldsymbol C} \in \mathbb{R}^{(k-1)\times d}} \min_{ {\boldsymbol u} \in \mathbb{R}^d} \left\{ {\boldsymbol u}^T {\boldsymbol A} {\boldsymbol u} : \| {\boldsymbol u} \|_2 = 1, \, {\boldsymbol C}{\boldsymbol u} = {\boldsymbol 0}\right\}\]

<p>This is a bi-level optimization problem, which we can think of as a two-turn game. The first player chooses \({\boldsymbol C}\) with \(k-1\) rows (i.e., \(k-1\) linear constraints). In response, the second player chooses a unit vector \({\boldsymbol u}\) that is in the null-space of \({\boldsymbol C}\), or equivalently, <em>orthogonal</em> to the rows of \({\boldsymbol C}\).</p>

<p>The objective of the second player is to pay as little as possible, where \({\boldsymbol u}^T {\boldsymbol A} {\boldsymbol u}\) is the cost. But the objective of the first player, of course, is to make their opponent pay as much as possible, so they choose a “worst case” matrix \({\boldsymbol C}\). In case you were wondering, when \(k = 1\) we have no adversarial player, and the vector \(\boldsymbol u\) can be an arbitrary unit vector. And you have probably guessed: at any equilibrium, \({\boldsymbol u}\) is an eigenvector corresponding to \(\lambda_k({\boldsymbol A})\).</p>
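<p>The game can be checked numerically. The sketch below, using NumPy and SciPy’s <code class="language-plaintext highlighter-rouge">null_space</code>, computes the second player’s best response to a given \({\boldsymbol C}\) as the smallest eigenvalue of \({\boldsymbol A}\) restricted to the null space of \({\boldsymbol C}\). Random choices of \({\boldsymbol C}\) never exceed \(\lambda_k\), and putting the eigenvectors of the \(k-1\) smallest eigenvalues in the rows of \({\boldsymbol C}\) attains it. The matrix and sizes are arbitrary.</p>

```python
import numpy as np
from scipy.linalg import null_space


def best_response_value(A, C):
    """Second player's value: min u^T A u over unit u with C u = 0."""
    Q = null_space(C)                          # orthonormal basis of the null space of C
    return np.linalg.eigvalsh(Q.T @ A @ Q)[0]  # smallest eigenvalue of A restricted to it


rng = np.random.default_rng(0)
d, k = 6, 3                                    # k is 1-based, as in the text
M = rng.standard_normal((d, d))
A = (M + M.T) / 2
eigvals, eigvecs = np.linalg.eigh(A)
lam_k = eigvals[k - 1]

# Any C the first player picks gives a value of at most lambda_k ...
random_values = [best_response_value(A, rng.standard_normal((k - 1, d)))
                 for _ in range(1000)]
# ... and the eigenvectors of the k-1 smallest eigenvalues, as rows of C, attain it.
optimal_value = best_response_value(A, eigvecs[:, :k - 1].T)

print(lam_k + 1e-12 >= max(random_values), np.isclose(optimal_value, lam_k))
```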

<p>One way to read this is: \({\boldsymbol u}\) is a bounded allocation over \(d\) latent resources, and \(A_{i,j}\) is the cost associated with the interaction between resources \(i\) and \(j\), since:</p>

\[{\boldsymbol u}^T {\boldsymbol A} {\boldsymbol u} = \sum_{i=1}^d \sum_{j=1}^d A_{i,j} u_i u_j\]

<p>Depending on the context, you can give different interpretations. For example, the entries of \({\boldsymbol A}\) could describe how pairs of latent skills a student possesses interact, and \({\boldsymbol u}\) is a <em>test vector</em> that weights those skills, so that each pair \(i, j\) is weighted by \(u_i u_j\).</p>

<p>There is a mirror image of the above game if we sort eigenvalues from largest to smallest. The \(k\)-th <em>largest</em> eigenvalue, which is also the \((d-k+1)\)-th smallest one, can be written as</p>

\[\lambda_{d-k+1}({\boldsymbol A}) = \min_{ {\boldsymbol C} \in \mathbb{R}^{(k-1)\times d}} \max_{ {\boldsymbol u} \in \mathbb{R}^d} \left\{ {\boldsymbol u}^T {\boldsymbol A} {\boldsymbol u} : \| {\boldsymbol u} \|_2 = 1, \, {\boldsymbol C}{\boldsymbol u} = {\boldsymbol 0}\right\}\]

<p>We can think of it in terms of utility rather than cost: each entry in the matrix is a utility associated with a pair of resources. The first player chooses the constraint matrix \({\boldsymbol C}\) so that the second player gets as little utility as possible, whereas the second player, in response, chooses a vector \({\boldsymbol u}\) that maximizes their utility.</p>

<p>Adopting the \(\max-\min\) convention, we can write our “neuron” as:</p>

\[f({\boldsymbol x};{\boldsymbol A}_{0:n}) = \max_{ {\boldsymbol C} \in \mathbb{R}^{(k-1)\times d}} \min_{ {\boldsymbol u} \in \mathbb{R}^d} \left\{ {\boldsymbol u}^T \left({\boldsymbol A}_0 + \sum_{i=1}^n x_i {\boldsymbol A}_i \right) {\boldsymbol u} : \| {\boldsymbol u} \|_2 = 1, \, {\boldsymbol C}{\boldsymbol u} = {\boldsymbol 0}\right\}\]

<p>Consequently, each feature is associated with a matrix of some latent “costs”, and our features are just the weights of these costs. Suppose you’re in the insurance business, and someone asks you what your model is doing. You can explain something like “oh, we’re representing each feature of the insured using a table of latent skills to avoid claims, we sum them up, and simulate a game where we aim to elicit their worst-case ability to either avoid damage or absorb it without claiming”. You can give an example of what those “latent skills” could be, just as people explain what the latent features of a matrix factorization model in a movie recommendation system might represent.</p>

<h1 id="as-a-kind-of-recurrent-neural-network">As a (kind of) recurrent neural network</h1>

<p>Another way to characterize the eigenvalues of a symmetric matrix is as a <em>sequence</em> of optimization problems, using the Courant-Fischer theorem (Fischer here is Ernst Sigismund Fischer, not Ronald Fisher).</p>

<p>We just saw that the smallest eigenvalue is the minimum of a quadratic function over the unit sphere. The second smallest eigenvalue is similar, but the vector must be orthogonal to <em>one row</em>, since the matrix \(\boldsymbol C\) chosen by the first player has one row. For the next eigenvalue, the vector must be orthogonal to <em>two rows</em>, and so on. But we can be more precise about what these constraints are. One of the possible formulations of the Courant-Fischer theorem is</p>

<blockquote>
  <p>Let \(\boldsymbol A\) be a symmetric matrix with eigenvalues \(\lambda_1 \leq \lambda_2 \leq \cdots \leq \lambda_d\) with corresponding eigenvectors \({\boldsymbol u}_1, \dots, {\boldsymbol u}_d\). Then,</p>

\[\lambda_k = \min_{\boldsymbol u} \left\{ {\boldsymbol u}^T {\boldsymbol A} {\boldsymbol u} : \| {\boldsymbol u} \|_2 = 1,\,\langle {\boldsymbol u}, {\boldsymbol u}_1 \rangle = 0, \dots, \langle {\boldsymbol u}, {\boldsymbol u}_{k-1} \rangle = 0 \right\}\]
</blockquote>

<p>In other words, the \(k\)-th smallest eigenvalue is the minimum of \({\boldsymbol u}^T {\boldsymbol A} {\boldsymbol u}\) among all unit vectors orthogonal to eigenvectors corresponding to the previous eigenvalues. Thus, we can think of it as a recurrent process: computing each eigenvalue yields an eigenvector, and all eigenvectors up to \(k-1\) are used to compute the \(k\)-th eigenvalue. Visually, it looks like this:</p>

<p><img src="https://alexshtf.github.io/assets/pow_spec_recurrent.png" alt="" /></p>

<p>All steps share the same matrix \({\boldsymbol A}\) as their weights, just like recurrent neural networks share weights in each recurrent step.</p>
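<p>To make the recurrence concrete, here is a NumPy sketch in which each step solves the constrained problem from the theorem and hands its eigenvector on to the next step. The function name <code class="language-plaintext highlighter-rouge">eigvals_recurrent</code> and the choice of shifted power iteration as the per-step solver are my own assumptions, not a method from this series; it is far slower than a real eigensolver and only meant to mirror the diagram.</p>

```python
import numpy as np


def eigvals_recurrent(A, num, iters=2000, seed=0):
    """The `num` smallest eigenvalues of a symmetric A, computed one after another.

    Step k approximately minimizes u^T A u over unit vectors orthogonal to the
    eigenvectors found in steps 1..k-1, via power iteration on sigma*I - A.
    """
    rng = np.random.default_rng(seed)
    d = A.shape[0]
    sigma = np.abs(A).sum(axis=1).max()    # Gershgorin bound, so sigma >= lambda_max(A)
    B = sigma * np.eye(d) - A              # PSD; its largest eigenpairs are A's smallest ones
    prev = np.zeros((d, 0))                # eigenvectors from previous steps, as columns
    lams = []
    for _ in range(num):                   # every step reuses the same "weights" A
        u = rng.standard_normal(d)
        for _ in range(iters):
            u = u - prev @ (prev.T @ u)    # enforce orthogonality to previous eigenvectors
            u = B @ u
            u /= np.linalg.norm(u)
        lams.append(u @ A @ u)             # Rayleigh quotient at the minimizer = lambda_k
        prev = np.column_stack([prev, u])  # pass the new eigenvector on to the next step
    return np.array(lams)


rng = np.random.default_rng(1)
Q, _ = np.linalg.qr(rng.standard_normal((5, 5)))
A = Q @ np.diag([1.0, 2.0, 3.0, 4.0, 5.0]) @ Q.T
A = (A + A.T) / 2                          # remove round-off asymmetry
print(eigvals_recurrent(A, 3))             # close to [1, 2, 3]
```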

<p>Intuitively, as we move from \(\lambda_1\) upward, the function becomes more and more expressive. But it is tempting (and wrong) to conclude that the largest eigenvalue is the most expressive function of \({\boldsymbol A}\). After all, we can construct a mirror image of the above process by ordering the eigenvalues in decreasing order, in which \(\lambda_d\) plays the role of \(\lambda_1\). In practice, the richest behavior tends to come from the middle of the spectrum, as we saw in the plots.</p>

<p>Again, we can think of this process as a kind of repeated “game”, but this time there is only one player. In each turn the player aims to minimize their cost, but their “strategy” \({\boldsymbol u}\) becomes more and more restricted from turn to turn: they must try something “different”, namely orthogonal to the strategies they chose in the previous turns.</p>

<h1 id="as-a-difference-of-convex-functions">As a difference of convex functions</h1>

<p>Minimization of nonconvex functions is a long-standing challenge in optimization theory and practice. But sometimes knowing some additional information about a nonconvex function can substantially improve both the speed and the reliability of our ability to minimize it. One of these pieces of information is having an explicit representation of the function we aim to minimize (or maximize) as a <em>difference of convex (DC) functions</em>, namely,</p>

\[f({\boldsymbol x}) = g({\boldsymbol x}) - h({\boldsymbol x}),\]

<p>such that both \(g\) and \(h\) are convex. In fact, there is an entire body of literature and algorithms on DC optimization, and several well-known optimization software packages have dedicated support for it. A famous example is the <a href="https://github.com/cvxgrp/dccp">DCCP extension</a> for <a href="https://www.cvxpy.org/">CVXPY</a>.</p>

<p>Turns out our \(f({\boldsymbol x};{\boldsymbol A}_{0:n})\) has such an explicit representation, so it can be used directly with the DCCP extension and other similar software packages. Why is it useful? Well, suppose \(f\) models the expected cost of some decision \({\boldsymbol x}\) that we would like to minimize. A DC representation lets us apply dedicated DC algorithms, such as the convex-concave procedure that DCCP is built on, which repeatedly replaces the concave part \(-h\) by its linearization and solves the resulting convex problem.</p>

<p>The idea is based on the Ky Fan variational principle, which states that the <em>sum</em> of the eigenvalues \(\lambda_k, \lambda_{k+1}, \dots, \lambda_d\) can be written as</p>

\[\Lambda_k({\boldsymbol A}) = \sum_{i=k}^d \lambda_i({\boldsymbol A}) = \max_{ {\boldsymbol U} \in \mathbb{R}^{d \times (d-k+1)}} \left\{ \operatorname{tr}({\boldsymbol U}^T {\boldsymbol A} {\boldsymbol U}) : {\boldsymbol U}^T {\boldsymbol U} = {\boldsymbol I}  \right\}\]

<p>This looks a bit hairy, but the term we are maximizing is a <em>linear function</em> of \(\boldsymbol A\), since \(\operatorname{tr}({\boldsymbol U}^T {\boldsymbol A} {\boldsymbol U}) = \operatorname{tr}({\boldsymbol A} {\boldsymbol U} {\boldsymbol U}^T)\), even though it is a nonlinear function of \(\boldsymbol U\). And a pointwise maximum of linear functions is always convex, even if there are infinitely many of them. Consequently, with the convention \(\Lambda_{d+1} \equiv 0\), the \(k\)-th smallest eigenvalue can be written in an explicit DC form as:</p>

\[\lambda_k({\boldsymbol A}) = \Lambda_k({\boldsymbol A}) - \Lambda_{k+1}({\boldsymbol A})\]
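<p>Both the Ky Fan principle and the DC identity can be checked numerically. In the NumPy sketch below (the helper <code class="language-plaintext highlighter-rouge">ky_fan_sum</code> is a name I made up), the maximizing \({\boldsymbol U}\) consists of the eigenvectors of the \(d-k+1\) largest eigenvalues, random matrices with orthonormal columns never exceed \(\Lambda_k\), and \(\Lambda_k - \Lambda_{k+1}\) recovers \(\lambda_k\).</p>

```python
import numpy as np


def ky_fan_sum(A, k):
    """Lambda_k(A): sum of eigenvalues lambda_k..lambda_d (1-indexed, ascending order)."""
    return np.linalg.eigvalsh(A)[k - 1:].sum()


rng = np.random.default_rng(0)
d, k = 5, 3
M = rng.standard_normal((d, d))
A = (M + M.T) / 2
w, V = np.linalg.eigh(A)

# The maximizer: U = eigenvectors of the d-k+1 largest eigenvalues, shape (d, d-k+1)
U_opt = V[:, k - 1:]
print(np.isclose(np.trace(U_opt.T @ A @ U_opt), ky_fan_sum(A, k)))    # True

# Any other U with orthonormal columns gives a trace of at most Lambda_k(A)
traces = []
for _ in range(1000):
    U, _ = np.linalg.qr(rng.standard_normal((d, d - k + 1)))   # reduced QR: U^T U = I
    traces.append(np.trace(U.T @ A @ U))
print(max(traces) <= ky_fan_sum(A, k) + 1e-10)                 # True

# The DC identity: lambda_k = Lambda_k - Lambda_{k+1}
print(np.isclose(w[k - 1], ky_fan_sum(A, k) - ky_fan_sum(A, k + 1)))   # True
```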

<p>We can see it visually by plotting \(\lambda_k({\boldsymbol P} + x {\boldsymbol Q})\) and its two convex components as a function of \(x\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

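<span class="c1"># choose random symmetric matrices (eigvalsh reads only the lower
</span><span class="c1"># triangle, so we make that symmetrization explicit)
</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">P</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="n">Q</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="n">P</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span><span class="n">P</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span><span class="n">P</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">).</span><span class="n">T</span>
<span class="n">Q</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span><span class="n">Q</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">).</span><span class="n">T</span>

<span class="c1"># compute eigenvalues of P + x Q for x in [-3, 3]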
</span><span class="n">xs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="n">eigvals</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">P</span> <span class="o">+</span> <span class="n">xs</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">,</span> <span class="bp">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">Q</span><span class="p">)</span>

<span class="c1"># plot mid eigenvalue and its constituent convex functions
</span><span class="n">sum_top_3</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">eigvals</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">sum_top_2</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">eigvals</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">:],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">sum_top_3</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sa">r</span><span class="s">'$\Lambda_3$'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">sum_top_2</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sa">r</span><span class="s">'$\Lambda_4$'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">eigvals</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="sa">r</span><span class="s">'$\lambda_3$'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">fill_between</span><span class="p">(</span>
    <span class="n">xs</span><span class="p">,</span> <span class="n">sum_top_3</span><span class="p">,</span> <span class="n">sum_top_2</span><span class="p">,</span>
    <span class="n">color</span><span class="o">=</span><span class="s">'skyblue'</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="p">(</span><span class="n">sum_top_3</span> <span class="o">&gt;</span> <span class="n">sum_top_2</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">fill_between</span><span class="p">(</span>
    <span class="n">xs</span><span class="p">,</span> <span class="n">sum_top_3</span><span class="p">,</span> <span class="n">sum_top_2</span><span class="p">,</span>
    <span class="n">color</span><span class="o">=</span><span class="s">'red'</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="p">(</span><span class="n">sum_top_3</span> <span class="o">&lt;</span> <span class="n">sum_top_2</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">grid</span><span class="p">(</span><span class="bp">True</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_dc.png" alt="pow_spec_dc" /></p>

<p>Indeed, the blue and orange curves, which are sums of the top eigenvalues, are convex functions. The gap between them is shaded red where \(\Lambda_3 - \Lambda_4\) is negative, and blue where it is positive. The green curve is their difference \(\lambda_3\), and it exactly reflects the size and the sign of the gap.</p>
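<p>To make the “why is it useful” point concrete, here is a NumPy sketch of the classical convex-concave procedure (also known as DCA) applied to minimizing \(\lambda_3({\boldsymbol P} + x {\boldsymbol Q})\) over a scalar \(x \in [-3, 3]\). It is only a stand-in for what a DC solver such as DCCP does, not DCCP’s implementation; the matrices are fresh random ones, not those from the plot, and the helpers <code class="language-plaintext highlighter-rouge">Lam</code>, <code class="language-plaintext highlighter-rouge">dLam</code>, and <code class="language-plaintext highlighter-rouge">golden_min</code> are hypothetical names of mine. Each iteration linearizes \(h = \Lambda_4\) at the current point, which gives a convex upper bound on \(\lambda_3\) touching it at that point, and minimizes that bound with a golden-section search, so the objective should never increase.</p>

```python
import numpy as np


def sym(M):
    return (M + M.T) / 2


rng = np.random.default_rng(42)
P, Q = sym(rng.standard_normal((5, 5))), sym(rng.standard_normal((5, 5)))
k = 3                                           # minimize lambda_3(P + xQ) over x in [-3, 3]


def Lam(j, x):
    """Lambda_j(P + xQ): sum of eigenvalues j..d (ascending, 1-indexed); convex in x."""
    return np.linalg.eigvalsh(P + x * Q)[j - 1:].sum()


def f(x):
    return Lam(k, x) - Lam(k + 1, x)            # equals lambda_k(P + xQ)


def dLam(j, x):
    """A (sub)derivative of Lambda_j(P + xQ) in x: sum of q_i^T Q q_i over i >= j."""
    _, V = np.linalg.eigh(P + x * Q)
    Vj = V[:, j - 1:]
    return np.trace(Vj.T @ Q @ Vj)


def golden_min(phi, a, b, tol=1e-10):
    """Golden-section search for the minimizer of a convex 1-D function on [a, b]."""
    r = (np.sqrt(5) - 1) / 2
    c, d = b - r * (b - a), a + r * (b - a)
    while b - a > tol:
        if phi(c) < phi(d):
            b, d = d, c
            c = b - r * (b - a)
        else:
            a, c = c, d
            d = a + r * (b - a)
    return (a + b) / 2


# Convex-concave procedure: replace -h = -Lambda_{k+1} by its linearization at the
# current x (an upper bound on -h), then minimize the convex surrogate.
x = 2.5
history = [f(x)]
for _ in range(20):
    s = dLam(k + 1, x)
    x = golden_min(lambda z: Lam(k, z) - s * z, -3.0, 3.0)
    history.append(f(x))

print(f"x = {x:.4f}, lambda_k: {history[0]:.4f} -> {history[-1]:.4f}")
```

<p>Like other DC methods, this procedure typically only guarantees convergence to a critical point, so where it ends up depends on the starting \(x\).</p>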

<h1 id="recap">Recap</h1>

<p>This was a theoretical detour to understand what kind of functions we are fitting and how we can reason about them. We saw that we can interpret our “neuron” as a game between two players, as a recurrent process where each step solves a simple quadratic optimization problem, and as a difference of convex functions. In the next post we’re back to more practical stuff.</p>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="machine learning" /><category term="eigenvalue models" /><category term="spectral methods" /><category term="interpretability" /><category term="optimization" /><category term="minimax" /><category term="dc optimization" /><summary type="html"><![CDATA[Interpreting eigenvalue-based ML models: read the k-th eigenvalue as a two-player game using Courant principle, a sequential orthogonality process via Courant-Fischer theorem, and a difference-of-convex function using Ky Fan variational principle.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/pow_spec_recurrent.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/pow_spec_recurrent.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">I feel the need for Eigen-Speed</title><link href="https://alexshtf.github.io/2026/01/20/Spectrum-Speed.html" rel="alternate" type="text/html" title="I feel the need for Eigen-Speed" /><published>2026-01-20T00:00:00+00:00</published><updated>2026-01-20T00:00:00+00:00</updated><id>https://alexshtf.github.io/2026/01/20/Spectrum-Speed</id><content type="html" xml:base="https://alexshtf.github.io/2026/01/20/Spectrum-Speed.html"><![CDATA[<p align="center">
  <a href="https://colab.research.google.com/github/alexshtf/alexshtf.github.io/blob/master/assets/spectrum_power_speed.ipynb" target="_blank" rel="noopener">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" />
  </a>
</p>

<h1 id="intro">Intro</h1>

<p>Efficiency is quite a superpower in research. When training is fast, you can iterate: try an idea, get surprised, debug, tune, and move on. When training is slow, you start avoiding experiments you <em>should</em> be running, simply because they cost too much time.</p>

<p>This became very tangible in this eigenvalue-model series. In the previous posts we looked at models that predict the \(k\)-th eigenvalue of a learned symmetric matrix built from the input features, and we explored what they can represent, plus some robustness/interpretability properties. Naturally, the next step was to scale up the matrix size and see what happens in practice.</p>

<p>I did not show it in the last post, but if you take the California Housing experiment and run \(30\times 30\) matrices for 500 epochs, it takes <em>more than an hour</em> on Colab. And this is with an L4 GPU. I even tried an A100, and it didn’t meaningfully improve anything. So then I asked myself - what’s wrong?</p>

<p>Turns out the culprit is PyTorch itself. <code class="language-plaintext highlighter-rouge">torch.linalg.eigvalsh</code> has a note in the <a href="https://docs.pytorch.org/docs/2.9/generated/torch.linalg.eigvalsh.html">official documentation</a>: when the input is on CUDA, it synchronizes the device with the CPU. If eigenvalues sit inside your inner training loop, that synchronization becomes the bottleneck, and the rest of your model almost doesn’t matter.</p>

<p>So this post is a practical detour: we’ll make eigenvalue computation fast enough that it stops getting in the way. We’ll replace the slow call with a faster GPU implementation, and we’ll wrap it in a way that still supports backprop through the \(k\)-th eigenvalue. Once that’s done, the scaling experiments from the previous post become feasible again, and we can go back to asking the interesting questions.</p>

<p>(All execution speeds I measure in this post are on Colab, with an NVIDIA L4 GPU, with the 2025.10 runtime.)</p>

<h1 id="warm-up">Warm-up</h1>

<p>Let’s start from a simple eigenvalue computation test on the CPU. We’ll use Jupyter’s <code class="language-plaintext highlighter-rouge">%%time</code> cell magic to measure time. First, let’s create a mini-batch of 500 matrices of size \(100 \times 100\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="n">mats</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
</code></pre></div></div>

<p>Now let’s see how fast we can compute the sum of eigenvalues of all these matrices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>
<span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mats</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 223 ms, sys: 14.4 ms, total: 237 ms
Wall time: 188 ms
tensor(56.6666)
</code></pre></div></div>

<p>And now with NumPy:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="o">%%</span><span class="n">time</span>
<span class="c1"># ^ in another cell: %%time must be the first line of its cell
</span><span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mats</span><span class="p">.</span><span class="n">numpy</span><span class="p">()).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 942 ms, sys: 7.07 ms, total: 949 ms
Wall time: 337 ms

np.float32(56.665924)
</code></pre></div></div>

<p>Apparently, on the CPU, NumPy is almost twice as slow as PyTorch (337 ms vs. 188 ms wall time). So when our tensors are on the CPU, we can continue using PyTorch - it’s pretty fast.</p>

<p>Now let’s move to the GPU. Here is a similar piece of code - create a mini-batch of random matrices and compute the sum of their eigenvalues:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mats</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">)</span>
<span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mats</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<p>I ran the eigenvalue computation once as a “warmup”: I want PyTorch to do whatever setup it needs to run CUDA kernels, so that the next time we invoke <code class="language-plaintext highlighter-rouge">eigvalsh</code> it is a “clean” run, not contaminated by setup:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>
<span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mats</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 788 ms, sys: 811 µs, total: 789 ms
Wall time: 788 ms
tensor(-557.2836, device='cuda:0')
</code></pre></div></div>

<p>Whoa! It’s more than <em>twice</em> as slow as NumPy on the CPU, and about <em>four times</em> as slow as PyTorch on the CPU! Turns out PyTorch developers haven’t invested that much in general-purpose scientific computing on the GPU. It’s quite reasonable - it is not their main focus. So if we want to propose a new computational tool, it’s up to us to make it efficient!</p>

<p>PyTorch may not have invested much in GPU eigenvalue solvers, but that doesn’t mean other scientific computing libraries haven’t. CuPy, a library aiming to be “NumPy on CUDA”, has a very fast eigenvalue solver we can use. But how can we use it on PyTorch tensors?</p>

<p>Turns out there is a standard called <a href="https://github.com/dmlc/dlpack">DLPack</a> for representing multi-dimensional tensors in memory, and it is supported both by PyTorch and by CuPy. In PyTorch we have the <code class="language-plaintext highlighter-rouge">torch.utils.dlpack</code> package for converting a tensor to a DLPack “capsule” - a wrapper around its memory with appropriate metadata:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch.utils</span> <span class="kn">import</span> <span class="n">dlpack</span> <span class="k">as</span> <span class="n">torch_dlpack</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">)</span>
<span class="n">torch_dlpack</span><span class="p">.</span><span class="n">to_dlpack</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>&lt;capsule object "dltensor" at 0x7b9a1ae26760&gt;
</code></pre></div></div>

<p>We can use CuPy to consume this “capsule” and access <em>the same tensor</em>, but this time as a CuPy array:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">cupy</span> <span class="k">as</span> <span class="n">cp</span>

<span class="n">cp</span><span class="p">.</span><span class="n">from_dlpack</span><span class="p">(</span><span class="n">torch_dlpack</span><span class="p">.</span><span class="n">to_dlpack</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array([ 1, -2,  3])
</code></pre></div></div>

<p>Now let’s try computing the sum of eigenvalues of our PyTorch <code class="language-plaintext highlighter-rouge">mats</code> tensor containing the mini-batch of matrices using CuPy:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cupy_mats</span> <span class="o">=</span> <span class="n">cp</span><span class="p">.</span><span class="n">from_dlpack</span><span class="p">(</span><span class="n">torch_dlpack</span><span class="p">.</span><span class="n">to_dlpack</span><span class="p">(</span><span class="n">mats</span><span class="p">))</span>
<span class="n">cp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">cupy_mats</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array(-557.284, dtype=float32)
</code></pre></div></div>

<p>We got the same value, so apparently it’s working. Let’s time it:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>
<span class="n">cp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">cupy_mats</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 1.58 ms, sys: 0 ns, total: 1.58 ms
Wall time: 1.37 ms
array(-557.284, dtype=float32)
</code></pre></div></div>

<p>Now that’s FAST! 1.37 milliseconds instead of 788 - more than 500 times faster! It’s impressive, but part of this enormous speedup is because we aren’t yet doing anything to prepare for backpropagation.</p>
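<p>One caveat when timing GPU code: CuPy and PyTorch launch CUDA kernels asynchronously, so a wall-clock measurement may stop before the GPU has actually finished, unless something in the timed code forces synchronization. I have not checked whether <code class="language-plaintext highlighter-rouge">cp.linalg.eigvalsh</code> synchronizes internally, so treat exact ratios with some care. Below is a minimal sketch of a hypothetical <code class="language-plaintext highlighter-rouge">benchmark</code> helper (my own, not part of the post’s notebook) that does the warm-up explicitly and accepts an optional synchronization callback such as <code class="language-plaintext highlighter-rouge">torch.cuda.synchronize</code> or <code class="language-plaintext highlighter-rouge">cupy.cuda.Device().synchronize</code>. PyTorch’s own <code class="language-plaintext highlighter-rouge">torch.utils.benchmark</code> module is a more complete alternative.</p>

```python
import statistics
import time

import numpy as np


def benchmark(fn, *args, warmup=1, repeats=10, sync=None):
    """Median wall-clock seconds of fn(*args), measured after `warmup` untimed calls.

    `sync` is an optional zero-argument callable that blocks until queued GPU work
    has finished, e.g. torch.cuda.synchronize or cupy.cuda.Device().synchronize.
    """
    for _ in range(warmup):
        fn(*args)                      # untimed: lets libraries do one-off setup
    if sync is not None:
        sync()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()                     # stop the clock only after the GPU is done
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# CPU usage example; on the GPU one would also pass e.g. sync=torch.cuda.synchronize
mats_np = np.random.default_rng(0).standard_normal((100, 100, 100)).astype(np.float32)
print(f"{benchmark(np.linalg.eigvalsh, mats_np, repeats=3) * 1e3:.1f} ms")
```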

<p>Of course we got a CuPy array of eigenvalues. But we can easily use DLPack to convert it back to a PyTorch tensor. Since there is no memory copy, it practically incurs no cost, as you can see below:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>
<span class="n">eigvals_cupy</span> <span class="o">=</span> <span class="n">cp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">cupy_mats</span><span class="p">)</span>
<span class="n">torch_dlpack</span><span class="p">.</span><span class="n">from_dlpack</span><span class="p">(</span><span class="n">eigvals_cupy</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 1.66 ms, sys: 0 ns, total: 1.66 ms
Wall time: 1.31 ms
tensor(-557.2838, device='cuda:0')
</code></pre></div></div>

<p>Here, the eigenvalues were computed with CuPy, but their sum was computed with PyTorch. You can see that we got the same result at the same speed, since DLPack conversions just wrap the same GPU memory block, without any copies.</p>
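<p>The zero-copy behavior is easy to check even without a GPU. Recent NumPy versions implement the same DLPack protocol through <code class="language-plaintext highlighter-rouge">np.from_dlpack</code>, so the sketch below uses a NumPy array as a stand-in for the PyTorch/CuPy pair: the consumer’s array shares the producer’s buffer, and a write through one is visible through the other. I would expect <code class="language-plaintext highlighter-rouge">torch.from_dlpack</code> and <code class="language-plaintext highlighter-rouge">cp.from_dlpack</code> to behave the same way for the tensors above, but this snippet only exercises NumPy.</p>

```python
import numpy as np

src = np.arange(6, dtype=np.float32).reshape(2, 3)   # the "producer" array
view = np.from_dlpack(src)                            # consumed via the DLPack protocol

print(np.shares_memory(src, view))   # True: both names point at one buffer
src[0, 0] = 42.0                     # write through the producer...
print(view[0, 0])                    # ...and the consumer sees 42.0, so nothing was copied
```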

<p>So the process is simple:</p>

<ol>
  <li>Wrap our PyTorch tensor’s memory as a CuPy array via DLPack</li>
  <li>Compute eigenvalues using CuPy</li>
  <li>Convert eigenvalues back to PyTorch via DLPack</li>
</ol>

<p>There are further DLPack nuances about memory management, such as which object is responsible for actually freeing the memory when it’s no longer needed, and you should learn them if you wish to use DLPack. But they are out of the scope of this post.</p>

<p>What we have is still not enough to build a full-fledged function we can use for model training in PyTorch, since for training we also need <em>gradients</em>.</p>

<h1 id="eigenvalue-gradients">Eigenvalue gradients</h1>

<p>Consider the function</p>

\[f({\boldsymbol X}) = \lambda_k({\boldsymbol X})\]

<p>of a symmetric matrix \(\boldsymbol X\). Recall from linear algebra that eigenvalues are roots of the characteristic polynomial, and polynomial roots can have <em>multiplicities</em> - the same root can “repeat” multiple times.</p>

<p>For simplicity, assume for now that at our point of interest we have a simple eigenvalue, namely, with multiplicity 1. In this case, a well-known result from linear algebra is that it has a <em>unique</em> (up to sign) normalized eigenvector \({\boldsymbol q}_k({\boldsymbol X})\). Turns out that the function \(f\) is <em>differentiable</em> at such points, and the gradient is simple:</p>

\[\nabla f({\boldsymbol X}) = {\boldsymbol q}_k({\boldsymbol X}) {\boldsymbol q}_k({\boldsymbol X})^T.\]

<p>Thus, the only thing we need for back-propagation is the <em>eigenvector</em> corresponding to our desired eigenvalue - its outer product with itself is the gradient.</p>
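<p>An independent check that does not involve autograd at all is a central finite difference. For a symmetric direction \({\boldsymbol E}\), the formula predicts that the derivative of \(t \mapsto \lambda_k({\boldsymbol X} + t {\boldsymbol E})\) at \(t = 0\) equals \(\langle {\boldsymbol q}_k {\boldsymbol q}_k^T, {\boldsymbol E} \rangle = {\boldsymbol q}_k^T {\boldsymbol E} {\boldsymbol q}_k\). In this NumPy sketch the dimension, \(k\), and the step size are arbitrary choices of mine.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 5, 3                          # k is 1-indexed, as in the text
M = rng.standard_normal((d, d))
X = (M + M.T) / 2                    # random symmetric matrix (simple eigenvalues almost surely)
w, Q = np.linalg.eigh(X)
q = Q[:, k - 1]
G = np.outer(q, q)                   # claimed gradient of lambda_k at X

E = rng.standard_normal((d, d))
E = (E + E.T) / 2                    # symmetric perturbation direction
t = 1e-6
fd = (np.linalg.eigvalsh(X + t * E)[k - 1] - np.linalg.eigvalsh(X - t * E)[k - 1]) / (2 * t)
print(abs(fd - np.sum(G * E)))       # tiny: on the order of the finite-difference error
```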

<p>Let’s convince ourselves that this works with code. Here is the outer product, with itself, of the eigenvector corresponding to the middle eigenvalue of a \(5 \times 5\) matrix:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mat</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">25</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="n">w</span><span class="p">,</span> <span class="n">Q</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigh</span><span class="p">(</span><span class="n">mat</span><span class="p">)</span>
<span class="n">i</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">grad_mat</span> <span class="o">=</span> <span class="n">Q</span><span class="p">[:,</span> <span class="n">i</span><span class="p">].</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">@</span> <span class="n">Q</span><span class="p">[:,</span> <span class="n">i</span><span class="p">].</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">grad_mat</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[ 0.1389, -0.2269, -0.0168,  0.2359, -0.1102],
        [-0.2269,  0.3708,  0.0274, -0.3855,  0.1801],
        [-0.0168,  0.0274,  0.0020, -0.0285,  0.0133],
        [ 0.2359, -0.3855, -0.0285,  0.4008, -0.1873],
        [-0.1102,  0.1801,  0.0133, -0.1873,  0.0875]])
</code></pre></div></div>

<p>And here is the gradient PyTorch computes when we call <code class="language-plaintext highlighter-rouge">tensor.backward()</code> on the middle eigenvalue:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mat_param</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">25</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mat_param</span><span class="p">)</span>
<span class="n">w</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="n">backward</span><span class="p">()</span>
<span class="n">mat_param</span><span class="p">.</span><span class="n">grad</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[ 0.1389, -0.2269, -0.0168,  0.2359, -0.1102],
        [-0.2269,  0.3708,  0.0274, -0.3855,  0.1801],
        [-0.0168,  0.0274,  0.0020, -0.0285,  0.0133],
        [ 0.2359, -0.3855, -0.0285,  0.4008, -0.1873],
        [-0.1102,  0.1801,  0.0133, -0.1873,  0.0875]])
</code></pre></div></div>

<p>Things become more complicated when the eigenvalue is not simple, i.e., when it has a multiplicity of at least two. At such points the function is, in general, <em>not</em> differentiable, and this is exactly the cause of the “kinks” we saw in the first post in the series, where we set out to understand what kind of functions are representable by our “eigenvalue neuron”.</p>
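<p>A tiny example, constructed just for illustration, makes the kink concrete. The 2×2 symmetric matrix with zeros on the diagonal and \(x\) off the diagonal has eigenvalues \(\pm\lvert x \rvert\), so its largest eigenvalue is exactly \(\lvert x \rvert\). The two eigenvalues meet at \(x = 0\), and that is precisely where the kink is. Away from zero, PyTorch’s gradient matches the derivative of \(\lvert x \rvert\), and the two one-sided slopes, \(-1\) and \(+1\), disagree:</p>

```python
import torch

def top_eigval(x: torch.Tensor) -> torch.Tensor:
    # Symmetric matrix [[0, x], [x, 0]]; its eigenvalues are -|x| and +|x|.
    zero = torch.zeros_like(x)
    M = torch.stack([torch.stack([zero, x]), torch.stack([x, zero])])
    # eigvalsh returns eigenvalues in ascending order, so [-1] is the largest one.
    return torch.linalg.eigvalsh(M)[-1]

results = {}
for val in [-0.5, 0.5]:
    x = torch.tensor(val, dtype=torch.float64, requires_grad=True)
    lam = top_eigval(x)
    lam.backward()
    results[val] = (lam.item(), x.grad.item())  # (value, derivative)
    print(f"x = {val:+.1f}: lambda_max = {lam.item():.4f}, d/dx = {x.grad.item():+.4f}")
```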

<p>There are many notions of “generalized derivatives”, and we will have to choose an appropriate one. Now, spoiler alert: we can still take one of the unit eigenvectors associated with \(\lambda_k\), call it \({\boldsymbol q}_k\), and use the matrix \({\boldsymbol q}_k {\boldsymbol q}_k^T\) for back-propagation. Now that we know what code to write, let’s try to understand <em>why</em>.</p>

<p>Consider the well-known ReLU function, which has a kink at zero. To the left of zero, the derivative is zero; to the right, it is one. At zero there is no derivative, but we can use any number between zero and one. Intuitively, this is because any line through the origin with a slope between zero and one behaves like a tangent: it touches the function at the kink and stays below it everywhere else. Note one important point: I said <em>any</em> number between zero and one. So we don’t have one slope we can use; we have infinitely many of them.</p>
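<p>Frameworks have to pick one of these slopes. For example, PyTorch’s ReLU uses slope zero at the kink, which is one admissible choice among many:</p>

```python
import torch

x = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
torch.relu(x).sum().backward()
print(x.grad)  # slopes at x = -1, at the kink x = 0, and at x = 1
```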

<p>A generalization of this idea, using the set of vectors “in between neighboring gradients”, is known as the Clarke sub-differential<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>. In higher dimensions, “in between” generalizes to the closure of the convex hull. I won’t go deep into the theory, so we won’t discuss exactly <em>what</em> we take the convex hull of, but intuitively these are the gradients at points in a small neighborhood. If you’re interested, I can recommend a great book<sup id="fnref:5" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">2</a></sup> by Frank Clarke himself :)</p>
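<p>For readers who do want the precise statement, here is the standard characterization for a locally Lipschitz function \(f\), which is differentiable almost everywhere by Rademacher’s theorem: take all sequences \({\boldsymbol x}_j \to {\boldsymbol x}\) of points where \(f\) is differentiable and along which the gradients converge, collect the limits, and take their convex hull:</p>

\[\partial_C f({\boldsymbol x}) = \operatorname{conv} \Bigl\{ \lim_{j \to \infty} \nabla f({\boldsymbol x}_j) \,:\, {\boldsymbol x}_j \to {\boldsymbol x},\ f \text{ is differentiable at every } {\boldsymbol x}_j \Bigr\}.\]

<p>For ReLU at zero, gradients from the left converge to \(0\) and gradients from the right converge to \(1\), so \(\partial_C \operatorname{ReLU}(0) = \operatorname{conv}\{0, 1\} = [0, 1]\), exactly the set of slopes discussed above.</p>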

<p>The Clarke sub-differential is the notion of generalized derivative that is typically accepted as the “right” one for back-prop<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">3</a></sup><sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">4</a></sup>. We are not always guaranteed to get an element of the Clarke sub-differential<sup id="fnref:3:1" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">4</a></sup> when backpropagating through a large graph, but we should do our best, at least for our atomic building blocks. And just like we can take any slope between 0 and 1 for ReLU, we can take <em>any</em> element of the sub-differential set. It turns out that the outer product of a unit eigenvector of \(\lambda_k\) with itself, \({\boldsymbol q}_k {\boldsymbol q}_k^T\), is an element of the Clarke sub-differential of the \(k\)-th eigenvalue function.</p>
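<p>Here is a small numerical illustration on the CPU, using PyTorch’s own <code class="language-plaintext highlighter-rouge">eigvalsh</code>, whose backward also returns \({\boldsymbol q}{\boldsymbol q}^T\). The matrix \(\operatorname{diag}(1, 1, 3)\) is our own toy example: its smallest eigenvalue, \(1\), has multiplicity two, with eigenspace spanned by the first two standard basis vectors. Nudging the matrix slightly in different directions makes that eigenvalue simple again, and the gradient at each nudged matrix is \({\boldsymbol q}{\boldsymbol q}^T\) for a <em>different</em> unit vector \({\boldsymbol q}\) from that eigenspace. The nearby gradients do not agree, which is the non-differentiability, and at the kink itself the solver still hands back one unit eigenvector from the same eigenspace:</p>

```python
import torch

dt = torch.float64
# Symmetric matrix whose smallest eigenvalue, 1, has multiplicity two (eigenspace = span{e0, e1}).
A = torch.diag(torch.tensor([1.0, 1.0, 3.0], dtype=dt))

def grad_of_smallest_eigval(M: torch.Tensor) -> torch.Tensor:
    # PyTorch's eigvalsh backward returns q q^T, with q the corresponding unit eigenvector.
    M = M.clone().requires_grad_(True)
    torch.linalg.eigvalsh(M)[0].backward()
    return M.grad

eps = 1e-4
directions = {
    "raise entry (1,1)": torch.diag(torch.tensor([0.0, 1.0, 0.0], dtype=dt)),
    "raise entry (0,0)": torch.diag(torch.tensor([1.0, 0.0, 0.0], dtype=dt)),
    "couple 0 and 1": torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=dt),
}
# Gradients at nearby matrices where the eigenvalue is simple: a different q q^T per direction.
grads = {name: grad_of_smallest_eigval(A + eps * D) for name, D in directions.items()}
for name, G in grads.items():
    print(name, G.diagonal().tolist(), G[0, 1].item())

# At the kink itself, the solver still returns some unit eigenvector, lying in span{e0, e1}.
q = torch.linalg.eigh(A).eigenvectors[:, 0]
print("eigenvector at the kink:", q.tolist())
```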

<p>Now we have our two ingredients - a way to quickly compute eigenvalues and eigenvectors on the GPU, and a way to compute the gradient for backpropagation - so let’s finally create our PyTorch function!</p>

<h1 id="a-custom-k-th-eigenvalue-function">A custom \(k\)-th eigenvalue function</h1>

<p>First, we’ll need two utilities to convert tensors from PyTorch to CuPy and back via DLPack:</p>
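<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">cupy</span> <span class="k">as</span> <span class="n">cp</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.utils.dlpack</span> <span class="k">as</span> <span class="n">torch_dlpack</span>

<span class="k">def</span> <span class="nf">_torch_to_cupy</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>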
    <span class="s">""" Zero-copy via DLPack for CUDA """</span>
    <span class="k">return</span> <span class="n">cp</span><span class="p">.</span><span class="n">from_dlpack</span><span class="p">(</span><span class="n">torch_dlpack</span><span class="p">.</span><span class="n">to_dlpack</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>

<span class="k">def</span> <span class="nf">_cupy_to_torch</span><span class="p">(</span><span class="n">x_cupy</span><span class="p">):</span>
    <span class="s">""" Zero-copy via DLPack for CUDA """</span>
    <span class="k">return</span> <span class="n">torch_dlpack</span><span class="p">.</span><span class="n">from_dlpack</span><span class="p">(</span><span class="n">x_cupy</span><span class="p">)</span>
</code></pre></div></div>
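<p>Since running CuPy requires a CUDA GPU, here is a CPU analog that may help to see what “zero-copy” means. It is a sketch under our own assumptions: NumPy, which also implements the DLPack protocol, stands in for CuPy, and the helper names <code class="language-plaintext highlighter-rouge">torch_to_numpy</code> and <code class="language-plaintext highlighter-rouge">numpy_to_torch</code> are hypothetical. Both directions share memory, so a write on one side is visible on the other:</p>

```python
import numpy as np
import torch

def torch_to_numpy(x: torch.Tensor) -> np.ndarray:
    """Zero-copy view of a CPU tensor as a NumPy array (hypothetical helper)."""
    # PyTorch refuses to export tensors that require grad via DLPack,
    # which is why CuPyKthEigval.forward calls .detach() before converting.
    return np.from_dlpack(x.detach())

def numpy_to_torch(x_np: np.ndarray) -> torch.Tensor:
    """Zero-copy view of a NumPy array as a PyTorch tensor (hypothetical helper)."""
    return torch.from_dlpack(x_np)

t = torch.zeros(3, dtype=torch.float64)
a = torch_to_numpy(t)
t[0] = 5.0  # write through the PyTorch tensor...
print(a)    # ...and the NumPy array sees it, because the memory is shared

b = np.arange(3, dtype=np.float64)
u = numpy_to_torch(b)
b[1] = 7.0  # write through the NumPy array...
print(u)    # ...and the PyTorch tensor sees it
```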

<p>Implementing a custom PyTorch autograd function is quite simple: we just need to follow a template. We inherit from <code class="language-plaintext highlighter-rouge">torch.autograd.Function</code> and implement two static methods, <code class="language-plaintext highlighter-rouge">forward</code> and <code class="language-plaintext highlighter-rouge">backward</code>. The former computes our function and optionally saves anything required for computing the derivative. The latter receives the gradient with respect to the function’s output and returns the gradients with respect to its inputs. Moreover, to make things efficient, we split <code class="language-plaintext highlighter-rouge">forward</code> into two code paths: a cheaper one for when no derivatives are required (inference), which computes only eigenvalues, and another one for when derivatives are required (training), which also computes eigenvectors. So here it is:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">CuPyKthEigval</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">autograd</span><span class="p">.</span><span class="n">Function</span><span class="p">):</span>
    <span class="o">@</span><span class="nb">staticmethod</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">A</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lower</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">True</span><span class="p">):</span>
        <span class="c1"># A: PyTorch --&gt; CuPy
</span>        <span class="n">A_</span> <span class="o">=</span> <span class="n">A</span> <span class="k">if</span> <span class="n">A</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">()</span> <span class="k">else</span> <span class="n">A</span><span class="p">.</span><span class="n">contiguous</span><span class="p">()</span>
        <span class="n">A_cp</span> <span class="o">=</span> <span class="n">_torch_to_cupy</span><span class="p">(</span><span class="n">A_</span><span class="p">.</span><span class="n">detach</span><span class="p">())</span>

        <span class="c1"># Which part of A to use, in CuPy language
</span>        <span class="n">uplo</span> <span class="o">=</span> <span class="s">"L"</span> <span class="k">if</span> <span class="n">lower</span> <span class="k">else</span> <span class="s">"U"</span>

        <span class="k">if</span> <span class="n">ctx</span><span class="p">.</span><span class="n">needs_input_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span> <span class="c1"># for training
</span>            <span class="c1"># CuPy eigenvalues and eigenvectors
</span>            <span class="n">ws_cp</span><span class="p">,</span> <span class="n">Qs_cp</span> <span class="o">=</span> <span class="n">cp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigh</span><span class="p">(</span><span class="n">A_cp</span><span class="p">,</span> <span class="n">UPLO</span><span class="o">=</span><span class="n">uplo</span><span class="p">)</span>
            
            <span class="c1"># CuPy --&gt; PyTorch
</span>            <span class="n">ws</span> <span class="o">=</span> <span class="n">_cupy_to_torch</span><span class="p">(</span><span class="n">ws_cp</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">A</span><span class="p">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">A</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
            <span class="n">Qs</span> <span class="o">=</span> <span class="n">_cupy_to_torch</span><span class="p">(</span><span class="n">Qs_cp</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">A</span><span class="p">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">A</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
            
            <span class="c1"># Store k-th eigenvector for the derivative
</span>            <span class="n">ctx</span><span class="p">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">Qs</span><span class="p">[...,</span> <span class="n">k</span><span class="p">].</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span> <span class="c1"># for inference
</span>            <span class="n">ws_cp</span> <span class="o">=</span> <span class="n">cp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">A_cp</span><span class="p">,</span> <span class="n">UPLO</span><span class="o">=</span><span class="n">uplo</span><span class="p">)</span>
            <span class="n">ws</span> <span class="o">=</span> <span class="n">_cupy_to_torch</span><span class="p">(</span><span class="n">ws_cp</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">A</span><span class="p">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">A</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">ws</span><span class="p">[...,</span> <span class="n">k</span><span class="p">]</span> <span class="c1"># k-th eigenvalue
</span>
    <span class="o">@</span><span class="nb">staticmethod</span>
    <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_w</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
        <span class="p">(</span><span class="n">Q</span><span class="p">,)</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">saved_tensors</span>  <span class="c1"># (..., n, 1)
</span>        <span class="n">grad_w</span> <span class="o">=</span> <span class="n">grad_w</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">Q</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="n">grad_A</span> <span class="o">=</span> <span class="p">(</span><span class="n">Q</span> <span class="o">*</span> <span class="n">grad_w</span><span class="p">[...,</span> <span class="bp">None</span><span class="p">,</span> <span class="bp">None</span><span class="p">])</span> <span class="o">@</span> <span class="n">Q</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">grad_A</span><span class="p">,</span> <span class="bp">None</span><span class="p">,</span> <span class="bp">None</span>  <span class="c1"># no grad for `k` and `lower`
</span></code></pre></div></div>
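<p>If you don’t have a GPU at hand, the same template can still be exercised on the CPU. The sketch below is our own analog, not part of the original code: it keeps the structure of <code class="language-plaintext highlighter-rouge">CuPyKthEigval</code> but swaps CuPy for <code class="language-plaintext highlighter-rouge">torch.linalg.eigh</code> and <code class="language-plaintext highlighter-rouge">torch.linalg.eigvalsh</code>, and the class name <code class="language-plaintext highlighter-rouge">TorchKthEigval</code> is hypothetical. We check the hand-written backward with <code class="language-plaintext highlighter-rouge">torch.autograd.gradcheck</code>, which compares it against finite differences. Since the solver reads only one triangle, we symmetrize the input first, so that finite differences and the symmetric \({\boldsymbol q}{\boldsymbol q}^T\) gradient measure the same thing:</p>

```python
import torch

class TorchKthEigval(torch.autograd.Function):
    """CPU-friendly analog of CuPyKthEigval (hypothetical): torch.linalg instead of CuPy."""

    @staticmethod
    def forward(ctx, A: torch.Tensor, k: int, lower: bool = True):
        uplo = "L" if lower else "U"
        if ctx.needs_input_grad[0]:  # training: eigenvalues and eigenvectors
            ws, Qs = torch.linalg.eigh(A, UPLO=uplo)
            ctx.save_for_backward(Qs[..., k].unsqueeze(-1))  # k-th eigenvector, shape (..., n, 1)
        else:  # inference: eigenvalues only
            ws = torch.linalg.eigvalsh(A, UPLO=uplo)
        return ws[..., k]

    @staticmethod
    def backward(ctx, grad_w: torch.Tensor):
        (q,) = ctx.saved_tensors
        grad_A = (q * grad_w[..., None, None]) @ q.transpose(-1, -2)  # grad_w * q q^T
        return grad_A, None, None  # no gradient for `k` and `lower`

def sym_kth(X: torch.Tensor) -> torch.Tensor:
    # Symmetrize first, so that every entry of X affects the matrix the solver reads.
    return TorchKthEigval.apply((X + X.transpose(-1, -2)) / 2, 2)

torch.manual_seed(0)
X = torch.randn(3, 5, 5, dtype=torch.float64, requires_grad=True)  # a mini-batch of 3 matrices
ok = torch.autograd.gradcheck(sym_kth, (X,))
print("gradcheck passed:", ok)
```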

<p>It’s a bit lengthy, but straightforward: we just follow the sketch we laid out above. To use our new function, we call the static <code class="language-plaintext highlighter-rouge">apply</code> method of our new class:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mat_param</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">25</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">CuPyKthEigval</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">mat_param</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">w</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">mat_param</span><span class="p">.</span><span class="n">grad</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[ 0.1389, -0.2269, -0.0168,  0.2359, -0.1102],
        [-0.2269,  0.3708,  0.0274, -0.3855,  0.1801],
        [-0.0168,  0.0274,  0.0020, -0.0285,  0.0133],
        [ 0.2359, -0.3855, -0.0285,  0.4008, -0.1873],
        [-0.1102,  0.1801,  0.0133, -0.1873,  0.0875]], device='cuda:0')
</code></pre></div></div>

<p>Identical to the gradient we obtained previously, so, as a sanity check, this appears to be working. Now let’s measure the speed versus PyTorch on a mini-batch of 500 matrices of size \(100 \times 100\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mats</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">)</span>
<span class="n">mat_param</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span> 
<span class="n">w</span> <span class="o">=</span> <span class="n">CuPyKthEigval</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">mats</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 1.44 ms, sys: 993 µs, total: 2.43 ms
Wall time: 2.1 ms
</code></pre></div></div>
<p>Alright, 2 milliseconds. Pretty fast. What happens if we need backpropagation?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">CuPyKthEigval</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">mat_param</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="nb">sum</span><span class="p">()</span>
<span class="n">w</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 81.3 ms, sys: 1 µs, total: 81.3 ms
Wall time: 81 ms
</code></pre></div></div>

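<p>Much slower! But we also understand why: with backpropagation we need the eigenvectors, not just the eigenvalues, so the training path calls <code class="language-plaintext highlighter-rouge">eigh</code> instead of <code class="language-plaintext highlighter-rouge">eigvalsh</code>. What about PyTorch?</p>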

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 795 ms, sys: 0 ns, total: 795 ms
Wall time: 794 ms
</code></pre></div></div>
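<p>One caveat when reading these timings: CUDA kernels are launched asynchronously, so a wall-clock measurement such as <code class="language-plaintext highlighter-rouge">%%time</code> may stop before the GPU work has finished, unless something in the timed code synchronizes. Below is a minimal sketch, under our own assumptions, of a more careful measurement with <code class="language-plaintext highlighter-rouge">torch.utils.benchmark.Timer</code>, which synchronizes CUDA when it is available. We assume the PyTorch baseline is a forward and backward pass through <code class="language-plaintext highlighter-rouge">torch.linalg.eigvalsh</code>; the helper name <code class="language-plaintext highlighter-rouge">torch_kth_eigval_step</code> is ours, and without a GPU the sketch falls back to smaller matrices on the CPU:</p>

```python
import torch
from torch.utils import benchmark

def torch_kth_eigval_step(mats: torch.Tensor, k: int) -> None:
    """Hypothetical PyTorch baseline: forward and backward through torch.linalg.eigvalsh."""
    mats.grad = None  # avoid accumulating gradients across repetitions
    torch.linalg.eigvalsh(mats)[..., k].sum().backward()

device = "cuda" if torch.cuda.is_available() else "cpu"
batch, n = (500, 100) if device == "cuda" else (50, 30)  # smaller problem on the CPU
mat_param = torch.nn.Parameter(torch.randn(batch, n, n, device=device))

timer = benchmark.Timer(
    stmt="torch_kth_eigval_step(mat_param, 2)",
    globals={"torch_kth_eigval_step": torch_kth_eigval_step, "mat_param": mat_param},
)
measurement = timer.timeit(10)  # runs the statement 10 times, reports the time per run
print(measurement)
```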

<p>Apparently, our custom function, even with backprop, is almost 10 times faster! Things look much better, and we can move forward. As a final step, we wrap it in a convenience function that dispatches CUDA tensors to our CuPy-based function and falls back to PyTorch for all other tensors:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">faster_kth_eigvalh</span><span class="p">(</span>
        <span class="n">A</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>  <span class="o">*</span><span class="p">,</span> <span class="n">lower</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">True</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">:</span>
    <span class="k">if</span> <span class="n">A</span><span class="p">.</span><span class="n">is_cuda</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">CuPyKthEigval</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">lower</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">UPLO</span><span class="o">=</span><span class="s">"L"</span> <span class="k">if</span> <span class="n">lower</span> <span class="k">else</span> <span class="s">"U"</span><span class="p">)[...,</span> <span class="n">k</span><span class="p">]</span>
</code></pre></div></div>
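<p>Since the wrapper passes <code class="language-plaintext highlighter-rouge">lower</code> on as the choice of triangle, it may help to see what that choice means. Symmetric eigensolvers such as <code class="language-plaintext highlighter-rouge">torch.linalg.eigvalsh</code> and its CuPy counterpart read only one triangle of their input, so a dense, non-symmetric tensor (such as a <code class="language-plaintext highlighter-rouge">torch.randn</code> parameter) behaves exactly like the symmetric matrix assembled from that triangle. A quick CPU check on a random matrix of our own choosing:</p>

```python
import torch

torch.manual_seed(0)
M = torch.randn(4, 4, dtype=torch.float64)  # deliberately NOT symmetric

# Symmetric matrices assembled from a single triangle of M.
from_lower = torch.tril(M) + torch.tril(M, diagonal=-1).transpose(0, 1)
from_upper = torch.triu(M) + torch.triu(M, diagonal=1).transpose(0, 1)

w_lower = torch.linalg.eigvalsh(M, UPLO="L")  # reads only the lower triangle
w_upper = torch.linalg.eigvalsh(M, UPLO="U")  # reads only the upper triangle
print(w_lower)
print(w_upper)
```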

<p>Nice! So now we have a function that works quickly on a GPU and we can finally do an experiment that I was not able to do in the previous post within a reasonable amount of time - try even larger matrices!</p>

<h1 id="trying-it-out-in-practice">Trying it out in practice</h1>

<p>Recall that in the last post we implemented a class, called <code class="language-plaintext highlighter-rouge">MultivariateSpectral</code>, for the model family we study in this series:</p>

\[f(\mathbf{x}; {\boldsymbol \mu}, \mathbf{A}_{1:n}) = \lambda_k \left( \operatorname{diag}({\boldsymbol \mu}) + \sum_{i=1}^n x_i \mathbf{A}_i \right),\]

<p>where the non-decreasing vector \({\boldsymbol \mu}\) and the symmetric matrices \(\mathbf{A}_1, \dots, \mathbf{A}_n\) are the learned parameters. Here is a version of it that uses our new <code class="language-plaintext highlighter-rouge">faster_kth_eigvalh</code> function:</p>

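<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">MultivariateSpectral</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>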
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">num_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eigval_idx</span> <span class="o">=</span> <span class="n">eigval_idx</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mu</span> <span class="o">=</span> <span class="n">Nondecreasing</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="c1"># &lt;-- we wrote it in the last post
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">A</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">*</span> <span class="n">num_features</span><span class="p">)</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># batches of sum of x[i] * A[i]
</span>        <span class="n">nf</span><span class="p">,</span> <span class="n">dim</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
        <span class="n">feature_mat</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">@</span> <span class="bp">self</span><span class="p">.</span><span class="n">A</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">nf</span><span class="p">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="n">dim</span><span class="p">)).</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>

        <span class="c1"># diag(mu) replicated per batch
</span>        <span class="n">bias_mat</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">diagflat</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">mu</span><span class="p">()).</span><span class="n">expand_as</span><span class="p">(</span><span class="n">feature_mat</span><span class="p">)</span>

        <span class="c1"># batched eigenvalue computation
</span>        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">_compute_eigval</span><span class="p">(</span><span class="n">bias_mat</span> <span class="o">+</span> <span class="n">feature_mat</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_compute_eigval</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mat</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">faster_kth_eigvalh</span><span class="p">(</span><span class="n">mat</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">eigval_idx</span><span class="p">)</span>
</code></pre></div></div>
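<p>The one-liner that builds <code class="language-plaintext highlighter-rouge">feature_mat</code> may look cryptic: viewing each \(\mathbf{A}_i\) as a row of length <code class="language-plaintext highlighter-rouge">dim * dim</code> turns \(\sum_i x_i \mathbf{A}_i\) for the whole mini-batch into a single matrix product. Here is a small CPU check of that trick against an explicit <code class="language-plaintext highlighter-rouge">torch.einsum</code> and a plain Python loop; the sizes are arbitrary choices of ours:</p>

```python
import torch

torch.manual_seed(0)
batch, nf, dim = 4, 3, 5
x = torch.randn(batch, nf)     # mini-batch of feature vectors
A = torch.randn(nf, dim, dim)  # one matrix per feature

# The trick from MultivariateSpectral.forward: one matmul, then reshape.
feature_mat = (x @ A.view(nf, dim * dim)).view(-1, dim, dim)

# The same quantity written out: for each sample b, sum_i x[b, i] * A[i].
explicit = torch.einsum("bi,ijk->bjk", x, A)
loop = torch.stack([sum(x[b, i] * A[i] for i in range(nf)) for b in range(batch)])
print(feature_mat.shape)
```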

<p>To compare it against PyTorch, here is a variant that uses the regular PyTorch eigenvalue function: we just inherit from the above class and override the <code class="language-plaintext highlighter-rouge">_compute_eigval</code> method:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MultivariateSpectralTorch</span><span class="p">(</span><span class="n">MultivariateSpectral</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">_compute_eigval</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mat</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mat</span><span class="p">)[...,</span> <span class="bp">self</span><span class="p">.</span><span class="n">eigval_idx</span><span class="p">]</span>
</code></pre></div></div>

<p>Now we will use the functions we implemented in the last post to test ourselves again on supervised regression with the California Housing dataset.</p>

<p>In the last post we implemented the function <code class="language-plaintext highlighter-rouge">train_model_stream</code>, which trains the given model and yields a sequence of dictionaries containing the model and the training loss, and the function <code class="language-plaintext highlighter-rouge">add_spectral_norms</code>, which augments each of these dictionaries with the spectral norms of the learned matrices. We used those norms to obtain a global bound on the model’s sensitivity with respect to its features. Here we just use these helpers, assuming they are defined and that the dataset is already loaded and pre-processed. The notebook linked at the beginning of this post contains the full code.</p>

<p>So let’s measure how long 5 training epochs take with PyTorch eigenvalues. First, a small helper that chains the two functions above; then, again, we use the <code class="language-plaintext highlighter-rouge">%%time</code> Jupyter cell magic to measure time:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">training_stream</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n_epochs</span><span class="p">,</span> <span class="o">**</span><span class="n">train_kwargs</span><span class="p">):</span>
    <span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">add_spectral_norms</span><span class="p">(</span><span class="n">train_model_stream</span><span class="p">(</span>
        <span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="n">n_epochs</span><span class="p">,</span> <span class="o">**</span><span class="n">train_kwargs</span>
    <span class="p">))</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">MultivariateSpectralTorch</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">45</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">22</span><span class="p">)</span>
<span class="k">for</span> <span class="n">event</span> <span class="ow">in</span> <span class="n">training_stream</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="s">'tick'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tick
tick
tick
tick
tick
CPU times: user 1min 16s, sys: 466 ms, total: 1min 16s
Wall time: 1min 16s
</code></pre></div></div>

<p>Now let’s try it with our faster eigenvalue function:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">MultivariateSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">45</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">22</span><span class="p">)</span>
<span class="k">for</span> <span class="n">event</span> <span class="ow">in</span> <span class="n">training_stream</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="s">'tick'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tick
tick
tick
tick
tick
CPU times: user 15.7 s, sys: 205 ms, total: 15.9 s
Wall time: 16.1 s
</code></pre></div></div>

<p>Nice! Almost five times faster! So now I can actually conduct an experiment I could not run in the previous post: see how the model scales if I increase the matrix size even further, to \(45 \times 45\). To that end, we shall re-use the <code class="language-plaintext highlighter-rouge">plot_progress</code> function, which consumes the event stream produced by training and draws a live-updating plot of the progress. Again, I assume the function is given, but you have the full code in the notebook.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">live_plot_training</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">n_epochs</span><span class="p">,</span> <span class="o">**</span><span class="n">train_kwargs</span><span class="p">):</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">MultivariateSpectral</span><span class="p">(</span>
        <span class="n">num_features</span><span class="o">=</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span>
    <span class="p">)</span>
    <span class="n">events</span> <span class="o">=</span> <span class="n">training_stream</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n_epochs</span><span class="p">,</span> <span class="o">**</span><span class="n">train_kwargs</span><span class="p">)</span>
    <span class="n">plot_progress</span><span class="p">(</span>
        <span class="n">events</span><span class="p">,</span> <span class="n">max_step</span><span class="o">=</span><span class="n">n_epochs</span>
    <span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>
<span class="n">live_plot_training</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">45</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">5e-5</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: user 30min 45s, sys: 23.7 s, total: 31min 8s
Wall time: 31min 23s
</code></pre></div></div>

<p>Well, it took about half an hour, which is quite long. But I was able to produce this plot:</p>

<p><img src="https://alexshtf.github.io/assets/pow_spec_props_norms_45.png" alt="pow_spec_props_norms_45" /></p>

<p>Recall that for a \(30 \times 30\) matrix we got a test error of \(\approx \$54200\), so scaling up indeed improves performance somewhat, but not dramatically. Apparently, with our current training procedure, we are beginning to see diminishing returns from this type of scaling.</p>

<p>Now, this does <em>not</em> mean that our training procedure is the best, and this is definitely not an exhaustive scaling experiment, in which we would choose the best training procedure we can and perhaps devise a rule for transferring hyperparameters from smaller to larger models. But the ability to compute eigenvalues quickly is what lets us conduct this research at all, since the PyTorch eigenvalue solver was simply too slow.</p>

<h1 id="recap">Recap</h1>

<p>Now that we can conduct fast experiments, we can move forward and do other interesting stuff. Obviously, there might be better ways to achieve our goal; for example, a custom CUDA kernel for the entire function \(f(\mathbf{x}; {\boldsymbol \mu}, \mathbf{A}_{1:n})\) might be faster still. But I just wanted something that doesn’t get in my way when I’m experimenting, that’s all.</p>

<p>The next post in the series will be very different: it will be theoretical. We did a lot of practical things here, but there are some things we need to understand before we move forward. So stay tuned!</p>

<hr />

<p><strong>References</strong></p>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>Clarke, Frank H. “Generalized gradients and applications.” <em>Transactions of the American Mathematical Society</em> 205 (1975): 247-262. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:5" role="doc-endnote">
      <p>Clarke, Frank H. <em>Optimization and nonsmooth analysis</em>. Society for Industrial and Applied Mathematics, 1990. <a href="#fnref:5" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>Park, Sejun, Sanghyuk Chun, and Wonyeol Lee. “What does automatic differentiation compute for neural networks?” <em>The Twelfth International Conference on Learning Representations</em>. 2024. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:3" role="doc-endnote">
      <p>Bolte, Jérôme, Tam Le, and Edouard Pauwels. “Subgradient sampling for nonsmooth nonconvex minimization.” <em>SIAM Journal on Optimization</em> 33.4 (2023): 2542-2569. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:3:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p>
    </li>
  </ol>
</div>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="machine learning" /><category term="eigenvalue models" /><category term="spectral methods" /><category term="pytorch" /><category term="cuda" /><category term="cupy" /><category term="dlpack" /><category term="autograd" /><summary type="html"><![CDATA[PyTorch eigenvalues on CUDA can be unexpectedly slow due to device synchronization. This post shows how to call CuPy via DLPack for fast GPU eigvalsh/eigh while keeping gradients for training.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/pow_spec_props_norms_45.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/pow_spec_props_norms_45.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Robustness, interpretability, and scaling of eigenvalue models</title><link href="https://alexshtf.github.io/2026/01/01/Spectrum-Props.html" rel="alternate" type="text/html" title="Robustness, interpretability, and scaling of eigenvalue models" /><published>2026-01-01T00:00:00+00:00</published><updated>2026-01-01T00:00:00+00:00</updated><id>https://alexshtf.github.io/2026/01/01/Spectrum-Props</id><content type="html" xml:base="https://alexshtf.github.io/2026/01/01/Spectrum-Props.html"><![CDATA[<p align="center">
  <a href="https://colab.research.google.com/github/alexshtf/alexshtf.github.io/blob/master/assets/spectrum_power_stability_robustness.ipynb" target="_blank" rel="noopener">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" />
  </a>
</p>

<h1 id="intro">Intro</h1>

<p>We all want our models to perform well. But some of us would also like our models to be efficient, robust, or interpretable. So in this post we will discuss some mathematical properties of eigenvalue models that relate to these three pillars. Robustness and interpretability may mean different things to different people, so let me explain what I mean by them. As a general note, many of the things I am going to talk about also hold for complex Hermitian matrices, but we focus on real symmetric matrices in this post. So this is the first and the last time I mention complex numbers in this series.</p>

<p>By robustness I mean robustness to <em>corruption</em> or <em>noise</em>: bounded changes to the input yield bounded changes to the output, and this bound is <em>known</em>. This is important when we want to be sure that a small perturbation will not make our model “go wild” and predict something totally unreasonable.</p>

<p>Interpretability can also mean many things. It can be interpretability for us, scientists, so that we can explain to ourselves what the model does. Alternatively, it can mean that we can explain what the model does to a business stakeholder or a regulator. In the extreme case, it means we can explain to a user why our system made the decision it made based on their input, e.g., why am I not getting a better insurance premium? In this post we shall mostly talk about the first two aspects.</p>

<p>But let’s get started with a small debt I believe I owe you from the previous post - eliminating some of the redundancy.</p>

<h1 id="eliminating-redundancy">Eliminating redundancy</h1>

<p>We defined our models as</p>

\[f(\mathbf{x}) = \lambda_k \left(\mathbf{A}_0 + \sum_{i=1}^n x_i \mathbf{A}_i \right)\]

<p>Some of us may remember from linear algebra that the eigenvalues of a symmetric matrix are invariant under orthogonal similarity transformations \(\mathbf{A} \mapsto \mathbf{Q}\mathbf{A}\mathbf{Q}^\intercal\). So the representation of our model is not unique: we can replace all matrices \(\mathbf{A}_i\) by \(\mathbf{Q}\mathbf{A}_i\mathbf{Q}^\intercal\) for some orthogonal matrix \(\mathbf{Q}\) and obtain exactly the same model, since \(\mathbf{Q}\mathbf{A}_0\mathbf{Q}^\intercal + \sum_{i=1}^n x_i \mathbf{Q}\mathbf{A}_i\mathbf{Q}^\intercal = \mathbf{Q}\bigl(\mathbf{A}_0 + \sum_{i=1}^n x_i \mathbf{A}_i\bigr)\mathbf{Q}^\intercal\). Redundancy, of course, is not unique to this family. Matrix factorization models<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">1</a></sup> have a similar redundancy. But we can eliminate some of this redundancy.</p>
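<p>A quick numerical check of this invariance, sketched in NumPy rather than PyTorch: the sizes, the seed, and the helper names <code class="language-plaintext highlighter-rouge">sym</code> and <code class="language-plaintext highlighter-rouge">eigvals_at</code> are illustrative choices, not part of the post’s code. We draw random symmetric matrices and a random orthogonal \(\mathbf{Q}\) from a QR decomposition, and compare the spectra before and after the transformation.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 6, 3  # matrix size and number of features (illustrative values)

def sym(M):
    # symmetrize an arbitrary square matrix
    return (M + M.T) / 2

A = [sym(rng.standard_normal((d, d))) for _ in range(n + 1)]  # A_0, ..., A_n
x = rng.standard_normal(n)

# random orthogonal Q from the QR decomposition of a Gaussian matrix
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))

def eigvals_at(mats, x):
    # eigenvalues (ascending) of A_0 + sum_i x_i A_i
    M = mats[0] + sum(xi * Ai for xi, Ai in zip(x, mats[1:]))
    return np.linalg.eigvalsh(M)

original = eigvals_at(A, x)
rotated = eigvals_at([Q @ Ai @ Q.T for Ai in A], x)
print(np.allclose(original, rotated))  # True
```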

<p>Since \(\mathbf{A}_0\) is symmetric, it has a spectral decomposition:</p>

\[\mathbf{A}_0 = \mathbf{U} \operatorname{diag}({\boldsymbol\mu}) \mathbf{U}^\intercal,\]

<p>where \(\boldsymbol \mu\) is the vector of eigenvalues in some predefined order, such as non-increasing or non-decreasing, and \(\mathbf{U}\) is orthogonal. Applying the invariance above with \(\mathbf{Q} = \mathbf{U}^\intercal\), the model can be written as</p>

\[f(\mathbf{x}) = \lambda_k\left(\operatorname{diag}({\boldsymbol \mu}) + \sum_{i=1}^n x_i (\mathbf{U}^\intercal\mathbf{A}_i \mathbf{U})\right).\]
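<p>The same kind of check, sketched in NumPy with arbitrary illustrative sizes, confirms that replacing \(\mathbf{A}_0\) by \(\operatorname{diag}(\boldsymbol\mu)\) and each \(\mathbf{A}_i\) by \(\mathbf{U}^\intercal \mathbf{A}_i \mathbf{U}\) leaves every eigenvalue unchanged. Here <code class="language-plaintext highlighter-rouge">np.linalg.eigh</code> supplies both \(\boldsymbol\mu\), in non-decreasing order, and \(\mathbf{U}\).</p>

```python
import numpy as np

rng = np.random.default_rng(1)
d, n = 5, 2  # illustrative sizes

def sym(M):
    # symmetrize an arbitrary square matrix
    return (M + M.T) / 2

A0 = sym(rng.standard_normal((d, d)))
A = [sym(rng.standard_normal((d, d))) for _ in range(n)]
x = rng.standard_normal(n)

# spectral decomposition A_0 = U diag(mu) U^T; eigh returns mu in non-decreasing order
mu, U = np.linalg.eigh(A0)

lhs = np.linalg.eigvalsh(A0 + sum(xi * Ai for xi, Ai in zip(x, A)))
rhs = np.linalg.eigvalsh(np.diag(mu) + sum(xi * (U.T @ Ai @ U) for xi, Ai in zip(x, A)))
print(np.allclose(lhs, rhs))  # True
```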

<p>Thus, we can assume, without losing any representation power, that \(\mathbf{A}_0\) is diagonal with, for example, non-decreasing diagonal entries, and that our model is always of the form:</p>

\[f(\mathbf{x}) = \lambda_k \left( \operatorname{diag}({\boldsymbol \mu}) + \sum_{i=1}^n x_i \mathbf{A}_i \right),\]

<p>where \(\boldsymbol \mu\) is a non-decreasing vector, and \(\mathbf{A}_i\) are symmetric matrices. So let’s implement such a model in PyTorch. To that end, we need a way to represent a non-decreasing vector, which is quite easy: use <code class="language-plaintext highlighter-rouge">torch.nn.functional.softplus</code> to generate positive gaps, and take their cumulative sum. Also, I don’t know what the right initialization for \(\boldsymbol \mu\) is, so I chose uniformly spaced points between -1 and 1:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">Nondecreasing</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="n">init</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">start</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="p">[:</span><span class="mi">1</span><span class="p">])</span>
        <span class="c1"># store log(expm1(gap)), the inverse of softplus, so forward() initially reproduces init
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">increments</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="p">.</span><span class="n">diff</span><span class="p">().</span><span class="n">expm1</span><span class="p">().</span><span class="n">log</span><span class="p">())</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">start</span><span class="p">,</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">start</span> <span class="o">+</span> <span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="n">softplus</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">increments</span><span class="p">).</span><span class="n">cumsum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="p">])</span>
</code></pre></div></div>

<p>Let’s try it out:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">Nondecreasing</span><span class="p">(</span><span class="mi">10</span><span class="p">)()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
         0.7778,  1.0000], grad_fn=&lt;CatBackward0&gt;)
</code></pre></div></div>

<p>Appears to be working. Now, this may not be the best way to parameterize a non-decreasing vector, and you can probably think of other ways, but it appears to work reasonably well when we train models later in this post.</p>

<p>Now we can use it to implement a PyTorch module for the kind of functions we seek. The code is mostly straightforward; the only thing that requires explanation is the initialization of the matrices \(\mathbf{A}_i\), which we discuss right after the code snippet:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">torch.linalg</span> <span class="k">as</span> <span class="n">tla</span>

<span class="k">class</span> <span class="nc">MultivariateSpectral</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">num_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eigval_idx</span> <span class="o">=</span> <span class="n">eigval_idx</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mu</span> <span class="o">=</span> <span class="n">Nondecreasing</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">A</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">*</span> <span class="n">num_features</span><span class="p">)</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># batches of sum of x[i] * A[i]
</span>        <span class="n">nf</span><span class="p">,</span> <span class="n">dim</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
        <span class="n">feature_mat</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">@</span> <span class="bp">self</span><span class="p">.</span><span class="n">A</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">nf</span><span class="p">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="n">dim</span><span class="p">)).</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>

        <span class="c1"># diag(mu) replicated per batch
</span>        <span class="n">bias_mat</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">diagflat</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">mu</span><span class="p">()).</span><span class="n">expand_as</span><span class="p">(</span><span class="n">feature_mat</span><span class="p">)</span>

        <span class="c1"># batched eigenvalue computation (eigvalsh reads only the lower triangle by default, so each A[i] acts as a symmetric matrix)
</span>        <span class="n">eigvals</span> <span class="o">=</span> <span class="n">tla</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">bias_mat</span> <span class="o">+</span> <span class="n">feature_mat</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">eigvals</span><span class="p">[...,</span> <span class="bp">self</span><span class="p">.</span><span class="n">eigval_idx</span><span class="p">]</span>
</code></pre></div></div>

<p>Regarding initialization, I am making an educated guess here. It is known<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">2</a></sup> that, for large \(d\), the eigenvalues of a \(d \times d\) symmetric matrix whose entries on and above the diagonal are independent standard Gaussians approximately follow the semicircle distribution on \([-2\sqrt{d}, 2 \sqrt{d}]\). Dividing the entries by \(\sqrt{d}\), where \(d\) is <code class="language-plaintext highlighter-rouge">dim</code>, brings the spectrum to roughly \([-2, 2]\). Moreover, since we will be summing up <code class="language-plaintext highlighter-rouge">num_features</code> matrices, it makes sense to also divide by <code class="language-plaintext highlighter-rouge">num_features</code>, i.e., to initialize our matrices from a normal distribution with a standard deviation of \((\sqrt{d} \cdot \mathtt{num\_features})^{-1}\). Here, too, I don’t know if this is the best initialization, but it works reasonably well.</p>
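<p>To see the scaling behind this guess, here is a small NumPy experiment (the sizes and the seed are arbitrary choices for illustration). The first part checks that the extreme eigenvalues of a symmetric Gaussian matrix, divided by \(\sqrt{d}\), are close to \(\pm 2\). The second part mimics the initialization above: with standard deviation \((\sqrt{d} \cdot \mathtt{num\_features})^{-1}\), each matrix has a spectral radius of roughly \(2 / \mathtt{num\_features}\), so, by the triangle inequality, \(\sum_i x_i \mathbf{A}_i\) with features of order one has a spectral radius of order one as well. That reading of the initialization is my own interpretation. Like <code class="language-plaintext highlighter-rouge">torch.linalg.eigvalsh</code>, <code class="language-plaintext highlighter-rouge">np.linalg.eigvalsh</code> reads only the lower triangle, so the unsymmetrized random matrices behave as symmetric ones.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 400  # matrix dimension (the post's `dim`)

# Symmetric Gaussian matrix: off-diagonal entries of (G + G^T) / sqrt(2) have variance 1
G = rng.standard_normal((d, d))
W = (G + G.T) / np.sqrt(2)
eigs = np.linalg.eigvalsh(W)
print(eigs.min() / np.sqrt(d), eigs.max() / np.sqrt(d))  # close to -2 and 2

# Mimic the post's initialization: std = 1 / (sqrt(d) * num_features)
num_features = 8
A = rng.standard_normal((num_features, d, d)) / (np.sqrt(d) * num_features)

# eigvalsh uses the lower triangle of each matrix (UPLO='L' by default)
radii = np.abs(np.linalg.eigvalsh(A)).max(axis=-1)
print(radii.round(3))  # each roughly 2 / num_features = 0.25
```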

<p>As a sanity test, let’s try learning the concave function from the previous post:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="k">return</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">y</span><span class="o">+</span><span class="mf">0.5</span><span class="p">)</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="o">-</span><span class="n">y</span><span class="o">+</span><span class="mf">0.5</span><span class="p">))</span>
 
<span class="c1"># sample 10000 points on the graph of the function
</span><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">10000</span><span class="p">).</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">10000</span><span class="p">).</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">xy</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.2</span> <span class="o">*</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">10000</span><span class="p">)</span>
</code></pre></div></div>

<p>Here is a simple training loop to see whether the loss decreases. Let’s fit a <em>concave</em> model (smallest eigenvalue) with \(5 \times 5\) matrices:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">math</span>
<span class="kn">from</span> <span class="nn">itertools</span> <span class="kn">import</span> <span class="n">count</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">MultivariateSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">print_every</span> <span class="o">=</span> <span class="mi">100</span>

<span class="n">cum_loss</span> <span class="o">=</span> <span class="mf">0.</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">xyb</span><span class="p">,</span> <span class="n">zb</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">count</span><span class="p">(),</span> <span class="n">xy</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">z</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)):</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">xyb</span><span class="p">)</span> <span class="o">-</span> <span class="n">zb</span><span class="p">).</span><span class="n">square</span><span class="p">().</span><span class="n">mean</span><span class="p">()</span>
    <span class="n">cum_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">detach</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>

    <span class="n">optim</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
    <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
    <span class="n">optim</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
        
    <span class="k">if</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">print_every</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Loss = </span><span class="si">{</span><span class="n">cum_loss</span> <span class="o">/</span> <span class="n">print_every</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
        <span class="n">cum_loss</span> <span class="o">=</span> <span class="mf">0.</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Loss = 0.4582
Loss = 0.1355
Loss = 0.0774
Loss = 0.0649
Loss = 0.0626
Loss = 0.0520
Loss = 0.0534
Loss = 0.0516
Loss = 0.0464
Loss = 0.0468
</code></pre></div></div>

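<p>OK. The model appears to be learning: the loss decreases and approaches \(0.04\), the variance of the Gaussian noise (standard deviation \(0.2\)) we added to the labels, which is the best mean squared error we can expect on average. So now that we have eliminated most of the redundancy, let’s move on to more interesting stuff.</p>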

<h1 id="spectral-stability-and-its-consequences">Spectral stability and its consequences</h1>

<p>First, let us recall that any matrix has an associated <em>operator norm</em> - the maximum amount by which it can stretch a unit vector:</p>

\[\|\mathbf{A}\|_{\mathrm{op}} = \max_{\mathbf{x}} \left\{ \|\mathbf{A} \mathbf{x} \|_2 : \|\mathbf{x}\|_2 = 1 \right\}\]

<p>We can compute it reliably with <code class="language-plaintext highlighter-rouge">np.linalg.norm(A, ord=2)</code> or <code class="language-plaintext highlighter-rouge">torch.linalg.matrix_norm(A, ord=2)</code>; without <code class="language-plaintext highlighter-rouge">ord=2</code>, these functions compute the Frobenius norm instead. Why are we recalling it? It turns out there is a useful consequence of <a href="https://en.wikipedia.org/wiki/Weyl%27s_inequality">Weyl’s inequality</a> for symmetric matrices, spectral stability:</p>

\[\vert \lambda_k(\mathbf{A} + \mathbf{B}) - \lambda_k(\mathbf{A}) \vert \leq \|\mathbf{B}\|_{\mathrm{op}}.\]

<p>So if we take a symmetric matrix \(\mathbf{A}\) and “corrupt” or “perturb” it by another symmetric matrix \(\mathbf{B}\), the resulting eigenvalues do not change by more than \(\|\mathbf{B}\|_{\mathrm{op}}\).</p>
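<p>Here is a minimal NumPy sketch of both facts, using arbitrary random matrices: the operator norm computed with <code class="language-plaintext highlighter-rouge">ord=2</code> agrees with the largest absolute eigenvalue of a symmetric matrix, and no eigenvalue moves by more than \(\|\mathbf{B}\|_{\mathrm{op}}\) when we add \(\mathbf{B}\). Since <code class="language-plaintext highlighter-rouge">eigvalsh</code> returns eigenvalues sorted in ascending order, subtracting the two outputs compares \(\lambda_k\) with \(\lambda_k\) for every \(k\).</p>

```python
import numpy as np

rng = np.random.default_rng(2)
d = 8

def sym(M):
    # symmetrize an arbitrary square matrix
    return (M + M.T) / 2

A = sym(rng.standard_normal((d, d)))
B = 0.1 * sym(rng.standard_normal((d, d)))  # the "perturbation"

# Operator norm = largest singular value; ord=2 is required,
# because the default ord computes the Frobenius norm
op_norm_B = np.linalg.norm(B, ord=2)

# For a symmetric matrix, it equals the largest absolute eigenvalue
largest_abs_eig = np.abs(np.linalg.eigvalsh(B)).max()

# Spectral stability: eigvalsh sorts ascending, so this compares lambda_k with lambda_k
shift = np.abs(np.linalg.eigvalsh(A + B) - np.linalg.eigvalsh(A))
print(np.isclose(op_norm_B, largest_abs_eig))  # True
print(shift.max() <= op_norm_B)                 # True
```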

<p>Now, consider our model family, and suppose that the first feature \(x_1\) was perturbed by some noise \(\varepsilon\). By the spectral stability property, our model’s output will not change by more than \(\lvert\varepsilon\rvert \| \mathbf{A}_1 \|_{\mathrm{op}}\). And in general, if our feature vector was perturbed by some noise \(\boldsymbol \varepsilon\), we have:</p>

\[|f(\mathbf{x} + {\boldsymbol \varepsilon}) - f(\mathbf{x})| \leq \Biggl \|\sum_{i=1}^n \varepsilon_i \mathbf{A}_i \Biggr\|_{\mathrm{op}} \leq \sum_{i=1}^n |\varepsilon_i| \| \mathbf{A}_i \|_{\mathrm{op}}\]

<p>Now, we have two ways to interpret this bound. First, from the standpoint of robustness - we have a direct bound on the possible change of the prediction as a function of the noise \(\boldsymbol \varepsilon\). For example, if we care about the \(\ell_2\) norm of the noise and want to know what happens when \(\|\boldsymbol \varepsilon\|_2 \leq \alpha\), the Cauchy-Schwarz inequality implies that the model’s prediction changes by at most \(\alpha \sqrt{\sum_{i=1}^n \| \mathbf{A}_i \|^2_{\mathrm{op}}}\).</p>
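<p>The bound is cheap to evaluate. The NumPy sketch below builds a random eigenvalue model (all sizes, the eigenvalue index, and the noise level are illustrative assumptions), perturbs an input, and checks that the actual change in the prediction stays below both the per-feature bound and its Cauchy-Schwarz relaxation. The operator norms of all feature matrices are computed in one call by passing <code class="language-plaintext highlighter-rouge">axis=(1, 2)</code> to <code class="language-plaintext highlighter-rouge">np.linalg.norm</code>.</p>

```python
import numpy as np

rng = np.random.default_rng(3)
d, n, k = 6, 4, 2  # matrix size, number of features, eigenvalue index

def sym(M):
    # symmetrize an arbitrary square matrix
    return (M + M.T) / 2

A0 = sym(rng.standard_normal((d, d)))
A = np.stack([sym(rng.standard_normal((d, d))) for _ in range(n)])

def f(x):
    # k-th smallest eigenvalue of A_0 + sum_i x_i A_i
    return np.linalg.eigvalsh(A0 + np.tensordot(x, A, axes=1))[k]

op_norms = np.linalg.norm(A, ord=2, axis=(1, 2))  # ||A_i||_op for each feature

x = rng.standard_normal(n)
eps = 0.05 * rng.standard_normal(n)

actual = abs(f(x + eps) - f(x))
bound_l1 = np.sum(np.abs(eps) * op_norms)                      # sum_i |eps_i| ||A_i||_op
bound_l2 = np.linalg.norm(eps) * np.sqrt(np.sum(op_norms**2))  # Cauchy-Schwarz version
print(actual <= bound_l1 <= bound_l2)  # True
```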

<p>The second way to think of the bound is from the standpoint of interpretability: one notion of feature importance is a worst-case sensitivity bound. The quantity \(\| \mathbf{A}_i \|_{\mathrm{op}}\) upper-bounds how much the prediction can change when only feature \(x_i\) is perturbed, because changing feature \(x_i\) by \(\varepsilon\) changes the model’s prediction by at most \(\lvert\varepsilon\rvert \| \mathbf{A}_i \|_{\mathrm{op}}\). So this operator norm bounds the <em>effect</em> of feature \(x_i\) on the model’s prediction, much like the magnitude of a coefficient in a linear model.</p>
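<p>One practical caveat when reading these norms off the <code class="language-plaintext highlighter-rouge">MultivariateSpectral</code> module above: its parameter <code class="language-plaintext highlighter-rouge">A</code> is an unconstrained tensor of shape <code class="language-plaintext highlighter-rouge">(num_features, dim, dim)</code>, and <code class="language-plaintext highlighter-rouge">torch.linalg.eigvalsh</code> reads only its lower triangle by default. So the matrix that actually enters the model is the symmetric matrix built from that lower triangle, and its operator norm generally differs from that of the raw parameter. Here is a minimal NumPy sketch, assuming the weights were exported with <code class="language-plaintext highlighter-rouge">model.A.detach().numpy()</code>; the function name <code class="language-plaintext highlighter-rouge">feature_importance</code> is hypothetical.</p>

```python
import numpy as np

def feature_importance(A_raw):
    """Per-feature operator norms ||A_i||_op of the symmetric matrices that
    eigvalsh actually sees, assuming it reads only the lower triangle
    (the default UPLO='L' in both NumPy and PyTorch).

    A_raw: array of shape (num_features, dim, dim), e.g. model.A.detach().numpy()
    """
    # keep the lower triangle (including the diagonal) ...
    lower = np.tril(A_raw)
    # ... and mirror the strictly lower part above the diagonal
    A_sym = lower + np.swapaxes(np.tril(A_raw, k=-1), -1, -2)
    # operator (spectral) norm of each matrix in the batch
    return np.linalg.norm(A_sym, ord=2, axis=(-2, -1))


# usage with random stand-in weights (a trained model's A would go here)
rng = np.random.default_rng(4)
A_raw = rng.standard_normal((3, 5, 5))
importance = feature_importance(A_raw)
print(importance.shape)  # (3,)
```

<p>Sorting the features by these values gives the worst-case sensitivity ranking described above.</p>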

<p>We can use this knowledge in two ways. First, having trained a model, we can interrogate it for its robustness and feature-importance properties by computing the operator norms of all feature matrices. Second, we can add a regularization term that limits these operator norms. Let’s start with the first idea: observing the operator norms.</p>

<h1 id="observing-stability-bounds-in-practice">Observing stability bounds in practice</h1>

<p>We will do it with our beloved California Housing data-set that I use a lot in my blog posts, simply because it’s there on Colab. So let’s load it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>

<span class="n">train_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s">'sample_data/california_housing_train.csv'</span><span class="p">)</span>
<span class="n">test_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s">'sample_data/california_housing_test.csv'</span><span class="p">)</span>
</code></pre></div></div>

<p>You may recall from our previous blog posts that the dataset has four heavily skewed columns, to which we typically apply a log transformation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="n">skewed_columns</span> <span class="o">=</span> <span class="p">[</span><span class="s">'total_rooms'</span><span class="p">,</span> <span class="s">'total_bedrooms'</span><span class="p">,</span> <span class="s">'population'</span><span class="p">,</span> <span class="s">'households'</span><span class="p">]</span>
<span class="n">train_df</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">]</span> <span class="o">=</span> <span class="n">train_df</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">].</span><span class="nb">apply</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">)</span>
<span class="n">test_df</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">]</span> <span class="o">=</span> <span class="n">test_df</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">].</span><span class="nb">apply</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">)</span>
</code></pre></div></div>
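<p>To see why the log helps, here is a small numpy sketch. It uses synthetic log-normal data as a made-up stand-in for a count column like <code class="language-plaintext highlighter-rouge">population</code>, not the actual dataset, and computes the sample skewness before and after the transform. Note that <code class="language-plaintext highlighter-rouge">np.log</code> requires strictly positive values; <code class="language-plaintext highlighter-rouge">np.log1p</code> is a common alternative when zeros can occur.</p>

```python
import numpy as np

def sample_skewness(v):
    """Fisher-Pearson skewness: the mean of cubed z-scores (population moments)."""
    z = (v - v.mean()) / v.std()
    return float(np.mean(z ** 3))

# Hypothetical stand-in for a count-like column: strictly positive,
# heavily right-skewed log-normal samples.
rng = np.random.default_rng(0)
counts = rng.lognormal(mean=7.0, sigma=0.8, size=20_000)

skew_raw = sample_skewness(counts)
skew_log = sample_skewness(np.log(counts))  # the log of a log-normal is normal
print(f"skewness before log: {skew_raw:.2f}, after log: {skew_log:.2f}")
```

<p>Log-normal data becomes exactly normal after the log, so this is the best case; real columns are not exactly log-normal, so expect a less dramatic change.</p>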

<p>Our final preprocessing step is plain standardization, so that each column has zero mean and unit variance:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">StandardScaler</span>

<span class="n">scaler</span> <span class="o">=</span> <span class="n">StandardScaler</span><span class="p">().</span><span class="n">set_output</span><span class="p">(</span><span class="n">transform</span><span class="o">=</span><span class="s">'pandas'</span><span class="p">)</span>
<span class="n">train_scaled</span> <span class="o">=</span> <span class="n">scaler</span><span class="p">.</span><span class="n">fit_transform</span><span class="p">(</span><span class="n">train_df</span><span class="p">)</span>
<span class="n">test_scaled</span> <span class="o">=</span> <span class="n">scaler</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test_df</span><span class="p">)</span>

<span class="n">label_scale</span> <span class="o">=</span> <span class="n">scaler</span><span class="p">.</span><span class="n">scale_</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
</code></pre></div></div>
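<p>Why is a single number enough to convert errors back? <code class="language-plaintext highlighter-rouge">StandardScaler</code> maps each value \(v\) to \((v - \mu)/\sigma\), where \(\sigma\) is the population standard deviation. Residuals in normalized units are therefore the original residuals divided by \(\sigma\), while the shift \(\mu\) cancels. Here is a minimal numpy check on made-up labels and predictions:</p>

```python
import numpy as np

rng = np.random.default_rng(1)
# Made-up labels (think dollars) and noisy predictions of them.
y = rng.normal(200_000.0, 115_000.0, size=1_000)
y_pred = y + rng.normal(0.0, 50_000.0, size=1_000)

# Standardize with the label's mean and population std (ddof=0), like StandardScaler.
mu, sigma = y.mean(), y.std()
y_norm = (y - mu) / sigma
y_pred_norm = (y_pred - mu) / sigma

rmse_norm = np.sqrt(np.mean((y_pred_norm - y_norm) ** 2))  # normalized units
rmse_orig = np.sqrt(np.mean((y_pred - y) ** 2))            # original units

# mu cancels in the residuals, so multiplying by sigma undoes the scaling.
print(rmse_norm * sigma, rmse_orig)
```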

<p>We store the scale of the last column, which is the label, because we want to report evaluation metrics in the label’s original units rather than in normalized units. Before training, let’s convert both the train and test data to PyTorch tensors, which is more convenient:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">as_tensor</span>

<span class="k">def</span> <span class="nf">to_tensors</span><span class="p">(</span><span class="n">df</span><span class="p">):</span>
    <span class="n">target</span> <span class="o">=</span> <span class="s">'median_house_value'</span>
    <span class="k">return</span> <span class="p">(</span>
        <span class="n">as_tensor</span><span class="p">(</span><span class="n">df</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">).</span><span class="n">values</span><span class="p">),</span> 
        <span class="n">as_tensor</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="n">target</span><span class="p">].</span><span class="n">values</span><span class="p">)</span>
    <span class="p">)</span>

<span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">to_tensors</span><span class="p">(</span><span class="n">train_scaled</span><span class="p">)</span>
<span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">to_tensors</span><span class="p">(</span><span class="n">test_scaled</span><span class="p">)</span>

<span class="n">num_features</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">n_train</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">X_train</span><span class="p">)</span>
</code></pre></div></div>

<p>Alright! So now let’s write our training loop. Here is a fairly standard PyTorch loop for one epoch:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">train_epoch</span><span class="p">(</span>
        <span class="n">device</span><span class="p">,</span> <span class="n">net</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">regularizer</span><span class="p">,</span> <span class="n">X_batches</span><span class="p">,</span> <span class="n">y_batches</span>
    <span class="p">):</span>
    <span class="n">epoch_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">X_batches</span><span class="p">,</span> <span class="n">y_batches</span><span class="p">):</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>
        <span class="n">cost</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">+</span> <span class="n">regularizer</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
        <span class="n">cost</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>

        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="n">epoch_loss</span> <span class="o">+=</span> <span class="n">loss</span> <span class="o">*</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">epoch_loss</span> <span class="o">/</span> <span class="n">n_train</span><span class="p">).</span><span class="n">cpu</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>
</code></pre></div></div>
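<p>Multiplying each batch loss by the batch size and dividing by <code class="language-plaintext highlighter-rouge">n_train</code> at the end yields a size-weighted average. This matters because the last batch can be smaller than the others. The numpy sketch below uses made-up numbers to check that, for fixed predictions, this weighted average equals the MSE over all samples. In <code class="language-plaintext highlighter-rouge">train_epoch</code> the model changes between batches, so the reported number is a running average over the epoch rather than the loss of the final parameters.</p>

```python
import numpy as np

rng = np.random.default_rng(2)
y = rng.normal(size=103)        # 103 samples -> the last batch has only 3
y_pred = rng.normal(size=103)   # fixed, made-up predictions
batch_size = 10

# Accumulate each batch's MSE weighted by its size,
# mirroring `epoch_loss += loss * x.shape[0]` in train_epoch.
total = 0.0
for start in range(0, len(y), batch_size):
    yb = y[start:start + batch_size]
    pb = y_pred[start:start + batch_size]
    total += np.mean((pb - yb) ** 2) * len(yb)

epoch_mse = total / len(y)
full_mse = np.mean((y_pred - y) ** 2)
print(epoch_mse, full_mse)
```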

<p>The regularizer will become useful later in this post; it is simply an additional penalty term added to the loss. Here is our fairly standard training loop, with a twist: instead of returning only at the end, it <code class="language-plaintext highlighter-rouge">yield</code>s intermediate results after every epoch. Why? It fully decouples the training code from the reporting and plotting code:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">train_model_stream</span><span class="p">(</span>
        <span class="n">net</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span> <span class="n">regularizer</span><span class="o">=</span><span class="bp">None</span>
    <span class="p">):</span>
    <span class="n">device</span> <span class="o">=</span> <span class="s">'cuda'</span> <span class="k">if</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s">'cpu'</span>
    <span class="n">regularizer</span> <span class="o">=</span> <span class="n">regularizer</span> <span class="ow">or</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">model</span><span class="p">:</span> <span class="mf">0.</span><span class="p">)</span> <span class="c1"># by default - no reg.
</span>
    <span class="n">net</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    <span class="n">X_train_batches</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">).</span><span class="n">split</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
    <span class="n">y_train_batches</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">).</span><span class="n">split</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
    <span class="n">X_test_device</span> <span class="o">=</span> <span class="n">X_test</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    <span class="n">y_test_device</span> <span class="o">=</span> <span class="n">y_test</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>

    <span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">(</span><span class="n">net</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">n_epochs</span><span class="p">):</span>
        <span class="n">train_loss</span> <span class="o">=</span> <span class="n">train_epoch</span><span class="p">(</span>
            <span class="n">device</span><span class="p">,</span> <span class="n">net</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">regularizer</span><span class="p">,</span>
            <span class="n">X_train_batches</span><span class="p">,</span> <span class="n">y_train_batches</span>
        <span class="p">)</span>

        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="n">test_loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">net</span><span class="p">(</span><span class="n">X_test_device</span><span class="p">),</span> <span class="n">y_test_device</span><span class="p">)</span>
            <span class="n">test_loss</span> <span class="o">=</span> <span class="n">test_loss</span><span class="p">.</span><span class="n">cpu</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>

        <span class="k">yield</span> <span class="p">{</span>
            <span class="s">'step'</span><span class="p">:</span> <span class="n">epoch</span><span class="p">,</span>
            <span class="s">'model'</span><span class="p">:</span> <span class="n">net</span><span class="p">,</span>
            <span class="s">'train_error'</span><span class="p">:</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">train_loss</span><span class="p">)</span> <span class="o">*</span> <span class="n">label_scale</span><span class="p">,</span>
            <span class="s">'test_error'</span><span class="p">:</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">test_loss</span><span class="p">)</span> <span class="o">*</span> <span class="n">label_scale</span><span class="p">,</span>
        <span class="p">}</span>
</code></pre></div></div>
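<p>Here is a toy, pure-Python illustration of this decoupling. Generators are lazy, so the consumer decides how many epochs actually run. <code class="language-plaintext highlighter-rouge">fake_training_stream</code> is a hypothetical stand-in for <code class="language-plaintext highlighter-rouge">train_model_stream</code> that does no real training, and <code class="language-plaintext highlighter-rouge">logged</code> records which epochs were executed.</p>

```python
def fake_training_stream(n_epochs):
    """Hypothetical stand-in for train_model_stream: one event per 'epoch'."""
    for epoch in range(1, n_epochs + 1):
        # a real stream would call train_epoch(...) here, before yielding
        yield {'step': epoch, 'test_error': 100.0 / epoch}

epochs_run = []

def logged(stream):
    """Pass events through unchanged, recording which steps were produced."""
    for event in stream:
        epochs_run.append(event['step'])
        yield event

steps_seen = []
for event in logged(fake_training_stream(n_epochs=200)):
    steps_seen.append(event['step'])
    if event['step'] == 3:
        break  # stop consuming: epochs 4..200 are never run

print(steps_seen, epochs_run)
```

<p>Breaking out of the loop after step 3 means epochs 4 to 200 are never computed. The reporting code can therefore stop training without any flag inside the training loop.</p>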

<p>This is where we use the <code class="language-plaintext highlighter-rouge">label_scale</code> we stored earlier. The square root of the MSE is the RMSE in normalized units, and multiplying it by <code class="language-plaintext highlighter-rouge">label_scale</code> converts it back to the units of the original labels. Let’s run a few epochs to see how it works:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">MultivariateSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>
<span class="k">for</span> <span class="n">event</span> <span class="ow">in</span> <span class="n">train_model_stream</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="n">event</span><span class="p">[</span><span class="s">'step'</span><span class="p">],</span> <span class="n">event</span><span class="p">[</span><span class="s">'train_error'</span><span class="p">],</span> <span class="n">event</span><span class="p">[</span><span class="s">'test_error'</span><span class="p">],</span> <span class="n">sep</span><span class="o">=</span><span class="s">'</span><span class="se">\t</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>1	109072.75280327671	99536.88016316161
2	95295.4091119307	89540.08167748511
3	86628.42283455568	82857.71197495822
4	80906.37794419663	78535.01666073261
5	77319.691256485	75821.19027087984
</code></pre></div></div>

<p>OK, the model appears to be training nicely. Yielding intermediate results also lets us compose streams. For example, we can wrap the training stream in a new generator that yields the train and test errors together with the spectral norms of the feature matrices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">add_spectral_norms</span><span class="p">(</span><span class="n">stream</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">event</span> <span class="ow">in</span> <span class="n">stream</span><span class="p">:</span>
        <span class="n">model</span> <span class="o">=</span> <span class="n">event</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="c1"># remember - we're using only lower-triangular part of each A_i
</span>            <span class="n">matrices_sym</span> <span class="o">=</span> \
                <span class="n">model</span><span class="p">.</span><span class="n">A</span><span class="p">.</span><span class="n">tril</span><span class="p">()</span> <span class="o">+</span> <span class="n">model</span><span class="p">.</span><span class="n">A</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span><span class="n">diagonal</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>
            <span class="n">norms</span> <span class="o">=</span> <span class="n">tla</span><span class="p">.</span><span class="n">matrix_norm</span><span class="p">(</span><span class="n">matrices_sym</span><span class="p">,</span> <span class="nb">ord</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
            <span class="n">norms</span> <span class="o">=</span> <span class="n">norms</span><span class="p">.</span><span class="n">ravel</span><span class="p">().</span><span class="n">cpu</span><span class="p">().</span><span class="n">tolist</span><span class="p">()</span>
        
        <span class="k">yield</span> <span class="p">{</span>
            <span class="s">'step'</span><span class="p">:</span> <span class="n">event</span><span class="p">[</span><span class="s">'step'</span><span class="p">],</span>
            <span class="s">'train_error'</span><span class="p">:</span> <span class="n">event</span><span class="p">[</span><span class="s">'train_error'</span><span class="p">],</span>
            <span class="s">'test_error'</span><span class="p">:</span> <span class="n">event</span><span class="p">[</span><span class="s">'test_error'</span><span class="p">],</span>
        <span class="p">}</span> <span class="o">|</span> <span class="p">{</span>
            <span class="sa">f</span><span class="s">'norm_</span><span class="si">{</span><span class="n">feature_name</span><span class="si">}</span><span class="s">'</span><span class="p">:</span> <span class="n">norm</span> 
            <span class="k">for</span> <span class="n">feature_name</span><span class="p">,</span> <span class="n">norm</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">feature_names</span><span class="p">,</span> <span class="n">norms</span><span class="p">)</span>
        <span class="p">}</span>
</code></pre></div></div>
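<p>We can sanity-check the symmetrization trick and the choice of <code class="language-plaintext highlighter-rouge">ord=2</code> with numpy, used here as a stand-in for <code class="language-plaintext highlighter-rouge">torch.linalg</code>. On random, hypothetical parameters, the sketch verifies two things. First, mirroring the strictly lower triangle yields a symmetric matrix with the same lower triangle. Second, the spectral norm of that matrix (its largest singular value) equals its largest absolute eigenvalue, as it should for any symmetric matrix.</p>

```python
import numpy as np

rng = np.random.default_rng(3)
# Hypothetical unconstrained parameters: a batch of 4 square 5x5 matrices.
A = rng.normal(size=(4, 5, 5))

# Same construction as in add_spectral_norms: keep the lower triangle
# (with the diagonal) and mirror the strictly-lower part above the diagonal.
sym = np.tril(A) + np.swapaxes(np.tril(A, k=-1), -1, -2)

# For a symmetric matrix, the spectral norm (largest singular value)
# equals the largest absolute eigenvalue.
spec_norms = np.linalg.norm(sym, ord=2, axis=(-2, -1))
max_abs_eig = np.abs(np.linalg.eigvalsh(sym)).max(axis=-1)
print(spec_norms, max_abs_eig)
```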

<p>Let’s try it out. This time we’ll use the <code class="language-plaintext highlighter-rouge">rich</code> library for pretty printing, since Python’s built-in <code class="language-plaintext highlighter-rouge">print</code> writes the whole dictionary on a single line. Here are two training epochs:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">rich.pretty</span> <span class="kn">import</span> <span class="n">pprint</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">MultivariateSpectral</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>
<span class="k">for</span> <span class="n">event</span> <span class="ow">in</span> <span class="n">add_spectral_norms</span><span class="p">(</span><span class="n">train_model_stream</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">2</span><span class="p">)):</span>
    <span class="n">pprint</span><span class="p">(</span><span class="n">event</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>{
    'step': 1,
    'train_error': 105523.77536281559,
    'test_error': 97731.75678823328,
    'norm_longitude': 0.12640513479709625,
    'norm_latitude': 0.1759399175643921,
    'norm_housing_median_age': 0.15713410079479218,
    'norm_total_rooms': 0.17977216839790344,
    'norm_total_bedrooms': 0.16544003784656525,
    'norm_population': 0.19817979633808136,
    'norm_households': 0.2670281231403351,
    'norm_median_income': 0.33458226919174194
}
{
    'step': 2,
    'train_error': 94575.6796415192,
    'test_error': 88766.7939206047,
    'norm_longitude': 0.15076042711734772,
    'norm_latitude': 0.1756560057401657,
    'norm_housing_median_age': 0.1784549206495285,
    'norm_total_rooms': 0.18125228583812714,
    'norm_total_bedrooms': 0.15172506868839264,
    'norm_population': 0.19934602081775665,
    'norm_households': 0.26696428656578064,
    'norm_median_income': 0.47850197553634644
}


</code></pre></div></div>
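<p>Because events are plain dictionaries, other wrappers compose the same way. As an example, here is a hypothetical early-stopping wrapper, a sketch that is not used elsewhere in this post. It stops pulling events once the test error has not improved for <code class="language-plaintext highlighter-rouge">patience</code> consecutive steps. The toy usage feeds it made-up errors instead of a real training stream.</p>

```python
def early_stopping(stream, patience):
    """Pass events through; stop once test_error hasn't improved for `patience` steps.

    Hypothetical wrapper in the style of add_spectral_norms; it only assumes
    each event is a dict with a 'test_error' key.
    """
    best = float('inf')
    steps_without_improvement = 0
    for event in stream:
        yield event
        if event['test_error'] < best:
            best = event['test_error']
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                return  # stop pulling events, so no further epochs are trained

# Toy usage with made-up errors: improves for 3 steps, then stalls.
errors = [10.0, 8.0, 7.0, 7.5, 7.2, 7.1, 6.0]
events = ({'step': i + 1, 'test_error': e} for i, e in enumerate(errors))
stopped_steps = [ev['step'] for ev in early_stopping(events, patience=3)]
print(stopped_steps)
```

<p>With a real model, one could write something like <code class="language-plaintext highlighter-rouge">plot_progress(early_stopping(add_spectral_norms(train_model_stream(model, criterion, n_epochs=500)), patience=20), max_step=500)</code>, where <code class="language-plaintext highlighter-rouge">patience=20</code> is an arbitrary illustrative value.</p>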

<p>Nice! Now we can consume the stream and plot everything live. The function below is lengthy but mostly boilerplate. It plots two graphs: one with the train and test errors, and another with the spectral norms of the feature matrices. I added comments to make the code clear, but the principle is simple: we create empty plots and update them as new events arrive.</p>

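<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">IPython.display</span> <span class="kn">import</span> <span class="n">display</span>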

<span class="k">def</span> <span class="nf">plot_progress</span><span class="p">(</span><span class="n">events</span><span class="p">,</span> <span class="n">max_step</span><span class="p">):</span>
    <span class="c1"># create a plot with two axes - one for errors, one for norms
</span>    <span class="n">fig</span><span class="p">,</span> <span class="p">(</span><span class="n">err_ax</span><span class="p">,</span> <span class="n">norm_ax</span><span class="p">)</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span>
        <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">),</span> <span class="n">layout</span><span class="o">=</span><span class="s">'constrained'</span>
    <span class="p">)</span>

    <span class="c1"># create empty line objects
</span>    <span class="k">def</span> <span class="nf">plot_empty</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">([],</span> <span class="p">[],</span> <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>

    <span class="n">line_dict</span> <span class="o">=</span> <span class="p">{</span>
        <span class="s">'train_error'</span><span class="p">:</span> <span class="n">plot_empty</span><span class="p">(</span><span class="n">err_ax</span><span class="p">,</span> <span class="s">'train error'</span><span class="p">),</span>
        <span class="s">'test_error'</span><span class="p">:</span> <span class="n">plot_empty</span><span class="p">(</span><span class="n">err_ax</span><span class="p">,</span> <span class="s">'test error'</span><span class="p">),</span>
    <span class="p">}</span> <span class="o">|</span> <span class="p">{</span>
        <span class="sa">f</span><span class="s">'norm_</span><span class="si">{</span><span class="n">feature_name</span><span class="si">}</span><span class="s">'</span><span class="p">:</span> <span class="n">plot_empty</span><span class="p">(</span><span class="n">norm_ax</span><span class="p">,</span> <span class="n">feature_name</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">feature_names</span>
    <span class="p">}</span>

    <span class="c1"># setup axis properties
</span>    <span class="n">err_ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Error"</span><span class="p">)</span>
    <span class="n">norm_ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Matrix norms"</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">ax</span> <span class="ow">in</span> <span class="p">(</span><span class="n">err_ax</span><span class="p">,</span> <span class="n">norm_ax</span><span class="p">):</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"Step"</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_step</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">grid</span><span class="p">(</span><span class="bp">True</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>


    <span class="c1"># display figure and obtain its handle
</span>    <span class="n">h</span> <span class="o">=</span> <span class="n">display</span><span class="p">(</span><span class="n">fig</span><span class="p">,</span> <span class="n">display_id</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">close</span><span class="p">(</span><span class="n">fig</span><span class="p">)</span>

    <span class="c1"># iterate over events and update the plot
</span>    <span class="n">min_test_error</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s">'inf'</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">event</span> <span class="ow">in</span> <span class="n">events</span><span class="p">:</span>
        <span class="n">step</span> <span class="o">=</span> <span class="n">event</span><span class="p">[</span><span class="s">'step'</span><span class="p">]</span>
        <span class="n">min_test_error</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">min_test_error</span><span class="p">,</span> <span class="n">event</span><span class="p">[</span><span class="s">'test_error'</span><span class="p">])</span>
        <span class="n">err_ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">'Error (min test err = </span><span class="si">{</span><span class="n">min_test_error</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">)'</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">line_dict</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
            <span class="n">value</span> <span class="o">=</span> <span class="n">event</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
            <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">line</span><span class="p">.</span><span class="n">get_data</span><span class="p">(</span><span class="n">orig</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
            <span class="n">line</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">step</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">value</span><span class="p">))</span>

        <span class="k">for</span> <span class="n">axs</span> <span class="ow">in</span> <span class="p">(</span><span class="n">err_ax</span><span class="p">,</span> <span class="n">norm_ax</span><span class="p">):</span>
            <span class="n">axs</span><span class="p">.</span><span class="n">relim</span><span class="p">()</span>
            <span class="n">axs</span><span class="p">.</span><span class="n">autoscale_view</span><span class="p">()</span>

        <span class="n">fig</span><span class="p">.</span><span class="n">canvas</span><span class="p">.</span><span class="n">draw</span><span class="p">()</span>
        <span class="n">h</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="n">fig</span><span class="p">)</span>
</code></pre></div></div>

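<p>Alright! Let’s use it to train a model that predicts the middle eigenvalue (<code class="language-plaintext highlighter-rouge">eigval_idx = dim // 2</code>) of \(5 \times 5\) matrices:</p>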

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">live_plot_training</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">n_epochs</span><span class="p">):</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">MultivariateSpectral</span><span class="p">(</span>
        <span class="n">num_features</span><span class="o">=</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span>
    <span class="p">)</span>
    <span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>
    <span class="n">events</span> <span class="o">=</span> <span class="n">add_spectral_norms</span><span class="p">(</span><span class="n">train_model_stream</span><span class="p">(</span>
        <span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="n">n_epochs</span>
    <span class="p">))</span>
    <span class="n">plot_progress</span><span class="p">(</span><span class="n">events</span><span class="p">,</span> <span class="n">max_step</span><span class="o">=</span><span class="n">n_epochs</span><span class="p">)</span>

<span class="n">live_plot_training</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_props_norms_5.png" alt="pow_spec_props_norms_5" /></p>

<p>OK. The model is learning, and after 500 epochs the three features with the largest matrix norms are longitude, latitude, and population. What happens when we increase the model size? Let’s try \(15 \times 15\) matrices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">live_plot_training</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_props_norms_15.png" alt="pow_spec_props_norms_15" /></p>

<p>We see that the test loss decreases as the model grows, and even though the ranking between features is slightly different, the three strongest features remain longitude, latitude, and population. But we also see something else - the matrix norms keep growing, so after 500 epochs the model’s parameters do not appear to have converged. Perhaps more thorough hyper-parameter tuning would help, I don’t know. But I chose a conservative option of a small learning rate and many epochs for a reason - to show that scaling model size improves performance, while keeping our model almost as interpretable as if it were linear.</p>

<p>Let’s go even further up, to \(30 \times 30\) matrices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">live_plot_training</span><span class="p">(</span><span class="mi">30</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_props_norms_30.png" alt="pow_spec_props_norms_30" /></p>

<p>We see that the train and test errors go further down, and the three features previously at the top remain there. Again - scaling up improves performance, while keeping interpretability and computable robustness bounds.</p>

<p>So what we got here is really interesting! We have a model that is nonlinear and improves with scaling, while remaining interpretable in terms of feature sensitivity / importance, and we have an easy way to compute global sensitivity bounds (which can be loose).</p>
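<p>The global sensitivity bound comes from Weyl’s inequality: for symmetric matrices, \(|\lambda_k(\mathbf{X} + \mathbf{E}) - \lambda_k(\mathbf{X})| \le \|\mathbf{E}\|_{\mathrm{op}}\). Taking \(\mathbf{E} = \sum_i \delta_i \mathbf{A}_i\) and applying the triangle inequality gives \(|f(\mathbf{x} + \boldsymbol\delta) - f(\mathbf{x})| \le \sum_i |\delta_i| \, \|\mathbf{A}_i\|_{\mathrm{op}}\). Below is a small NumPy sketch that checks this numerically on random symmetric matrices. The helper names <code class="language-plaintext highlighter-rouge">spectral_model</code> and <code class="language-plaintext highlighter-rouge">sensitivity_bound</code> are hypothetical and not part of the notebook.</p>

```python
import numpy as np

def spectral_model(A0, As, x, k):
    """Hypothetical helper: λ_k(A0 + Σ_i x[i] * As[i]) for symmetric matrices, k is 0-based."""
    M = A0 + np.tensordot(x, As, axes=1)
    return np.linalg.eigvalsh(M)[k]

def sensitivity_bound(As, delta):
    """Weyl-based bound on |f(x + delta) - f(x)|, namely Σ_i |delta_i| * ||A_i||_op."""
    # for a symmetric matrix, the operator norm is its largest absolute eigenvalue
    op_norms = np.abs(np.linalg.eigvalsh(As)).max(axis=-1)
    return np.abs(delta) @ op_norms

rng = np.random.default_rng(0)
dim, n = 8, 4
G = rng.standard_normal((n + 1, dim, dim))
mats = (G + np.swapaxes(G, -1, -2)) / 2    # n + 1 random symmetric matrices
A0, As = mats[0], mats[1:]

x = rng.standard_normal(n)
ratios = []
for _ in range(1000):
    delta = 0.1 * rng.standard_normal(n)
    change = abs(spectral_model(A0, As, x + delta, 3) - spectral_model(A0, As, x, 3))
    ratios.append(change / sensitivity_bound(As, delta))
print(f"largest observed change / bound: {max(ratios):.3f}")  # should never exceed 1
```

<p>The per-feature norms \(\|\mathbf{A}_i\|_{\mathrm{op}}\) act as worst-case slopes, which is why they double as a feature-importance score. The bound can be loose, because the perturbations rarely align with the worst-case direction.</p>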

<p>As a reference, if you try fitting a gradient-boosted decision forest using XGBoost, you’ll observe a test error of approximately $48,000. So the eigenvalue model we see here isn’t close to what trees can achieve, but tree ensembles are often discontinuous and don’t come with simple global sensitivity/Lipschitz certificates in the same way. So it’s a tradeoff.</p>

<h1 id="sensitivity-control">Sensitivity control</h1>

<p>Another way we can use our understanding of the stability properties is to regularize the model by either imposing a bound on the maximum spectral norm, or adding a regularization term that penalizes the spectral norms, so our training code will be minimizing</p>

\[\min_{\mathbf{A}_{1:n}, \boldsymbol\mu} \quad \underbrace{\frac{1}{N} \sum_{i=1}^N (f(\mathbf{x}_i;\mathbf{A}_{1:n}, {\boldsymbol \mu}) - y_i)^2}_{\mathrm{loss}} + \underbrace{\alpha \sum_{i=1}^n \| \mathbf{A}_i \|_{\mathrm{op}}}_{\mathrm{penalty}}\]

<p>This is where we shall use the <code class="language-plaintext highlighter-rouge">regularizer</code> parameter of our training function that I promised you:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">live_plot_reg_training</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">n_epochs</span><span class="p">,</span> <span class="n">reg_coef</span><span class="p">):</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">MultivariateSpectral</span><span class="p">(</span>
        <span class="n">num_features</span><span class="o">=</span><span class="n">num_features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span>
    <span class="p">)</span>

    <span class="k">def</span> <span class="nf">penalty</span><span class="p">(</span><span class="n">net</span><span class="p">):</span>
        <span class="n">matrices_sym</span> <span class="o">=</span> \
            <span class="n">net</span><span class="p">.</span><span class="n">A</span><span class="p">.</span><span class="n">tril</span><span class="p">()</span> <span class="o">+</span> <span class="n">net</span><span class="p">.</span><span class="n">A</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span><span class="n">diagonal</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>
        <span class="n">norms</span> <span class="o">=</span> <span class="n">tla</span><span class="p">.</span><span class="n">matrix_norm</span><span class="p">(</span><span class="n">matrices_sym</span><span class="p">,</span> <span class="nb">ord</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">reg_coef</span> <span class="o">*</span> <span class="n">norms</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span>

    <span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>
    <span class="n">events</span> <span class="o">=</span> <span class="n">add_spectral_norms</span><span class="p">(</span><span class="n">train_model_stream</span><span class="p">(</span>
        <span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="n">n_epochs</span><span class="p">,</span> <span class="n">regularizer</span><span class="o">=</span><span class="n">penalty</span>
    <span class="p">))</span>
    <span class="n">plot_progress</span><span class="p">(</span><span class="n">events</span><span class="p">,</span> <span class="n">max_step</span><span class="o">=</span><span class="n">n_epochs</span><span class="p">)</span>
</code></pre></div></div>
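<p>To make the penalty concrete outside of PyTorch, here is a NumPy sketch of the same computation: symmetrize each stored matrix from its lower triangle, then sum the spectral norms. The names <code class="language-plaintext highlighter-rouge">symmetrize_lower</code> and <code class="language-plaintext highlighter-rouge">spectral_penalty</code> are hypothetical. As a cross-check, the sketch uses the fact that the spectral norm of a symmetric matrix equals its largest absolute eigenvalue.</p>

```python
import numpy as np

def symmetrize_lower(A):
    """Symmetric matrices built from the lower triangle of each matrix in a stack."""
    lower = np.tril(A)                                     # lower triangle incl. diagonal
    strict_upper = np.swapaxes(np.tril(A, k=-1), -1, -2)   # mirrored strict lower triangle
    return lower + strict_upper

def spectral_penalty(A, reg_coef):
    """reg_coef * Σ_i ||sym(A_i)||_op, a NumPy mirror of the PyTorch penalty above."""
    norms = np.linalg.norm(symmetrize_lower(A), ord=2, axis=(-2, -1))  # largest singular values
    return reg_coef * norms.sum()

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 6, 6))   # e.g. 4 features, 6x6 matrices, stored unsymmetrized
S = symmetrize_lower(A)
eig_norms = np.abs(np.linalg.eigvalsh(S)).max(axis=-1)
print(spectral_penalty(A, 1e-3), 1e-3 * eig_norms.sum())  # the two values should agree
```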

<p>Let’s try it out with \(15 \times 15\) matrices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">live_plot_reg_training</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mf">1e-3</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_props_norms_reg_15.png" alt="pow_spec_props_norms_reg_15" /></p>

<p>We can see that the spectral norms are smaller than in our previous attempt with \(15 \times 15\) matrices above, and their growth appears to stabilize, while performance appears similar. The main visible difference is that the gap between the top four features and the rest became more pronounced - that’s the effect of mild regularization. A larger regularization coefficient may even drive some of the matrices towards zero, similarly to \(\ell_1\) regularization in Lasso.</p>

<p>Imposing such a regularizer with standard PyTorch optimizers, rather than a dedicated optimizer, may not be the optimal (pun intended!) thing to do, and some of you can probably think of better ways. But that’s beside the point - the point is that we can, in principle, regularize the spectral norm to control the model’s sensitivity to feature perturbations. And that is quite powerful.</p>
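<p>One example of such a “better way”, offered here as an illustration and not as the method used in this post, is proximal gradient descent: take a gradient step on the loss alone, then apply the proximal operator of \(\eta \alpha \|\cdot\|_{\mathrm{op}}\) to each matrix, where \(\eta\) is the step size. For a symmetric matrix this operator acts on the eigenvalues: by Moreau’s decomposition, it subtracts their Euclidean projection onto the \(\ell_1\) ball of radius \(\eta\alpha\). Whenever the \(\ell_1\) norm of the eigenvalues is at most that radius, the matrix becomes exactly zero, which is the Lasso-like effect mentioned above. The function names below are hypothetical, and the projection uses the standard sort-based algorithm.</p>

```python
import numpy as np

def project_l1_ball(v, radius):
    """Euclidean projection of a vector onto {w : ||w||_1 <= radius} (sort-based)."""
    if np.abs(v).sum() <= radius:
        return v.copy()
    u = np.sort(np.abs(v))[::-1]                # magnitudes, descending
    cssv = np.cumsum(u)
    j = np.arange(1, len(u) + 1)
    rho = np.nonzero(u - (cssv - radius) / j > 0)[0][-1]
    theta = (cssv[rho] - radius) / (rho + 1)    # soft-threshold level
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)

def prox_op_norm(X, alpha):
    """Proximal operator of alpha * ||X||_op for a symmetric matrix X."""
    eigvals, eigvecs = np.linalg.eigh(X)
    # Moreau: prox of alpha*||.||_inf = identity minus projection onto the radius-alpha l1 ball
    shrunk = eigvals - project_l1_ball(eigvals, alpha)
    return (eigvecs * shrunk) @ eigvecs.T

small = np.diag([0.3, -0.2])
big = np.diag([3.0, 1.0])
print(prox_op_norm(small, 1.0))   # should be the zero matrix
print(prox_op_norm(big, 1.0))     # should be close to diag(2, 1): the top eigenvalue shrinks by alpha
```

<p>In a training loop this would look like <code class="language-plaintext highlighter-rouge">A_i = prox_op_norm(A_i - eta * grad_i, eta * alpha)</code> for each symmetrized matrix, replacing the penalty term in the loss.</p>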

<p>We’ve covered a lot of ground, so it’s time for a recap.</p>

<h1 id="summary">Summary</h1>

<p>We saw that matrix eigenvalues let us find a nice sweet spot between several opposing forces - performance, robustness, and interpretability. Beyond models for tabular data, this idea can also be employed for another use case we haven’t yet discussed - ensembling. There, too, we care that the ensemble’s prediction behaves “sensibly” w.r.t. the predictions of the individual models, and there too we may care about robustness and interpretability. So it’s nice to have a learnable ensembling technique that both improves with scaling and remains robust and somewhat interpretable.</p>

<p>We will study other mathematical properties in future posts that will let us understand on a deeper level what kind of information we can elicit from those models, but as of now we have a slightly more urgent concern: training is slow. We need many epochs, and each epoch is expensive. This makes experimentation hard - our feedback loop is slow as well. So this is something we shall try to address in the next post!</p>

<p><strong>References</strong></p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:2" role="doc-endnote">
      <p>Koren, Y., Bell, R., &amp; Volinsky, C. (2009). Matrix factorization techniques for recommender systems. <em>Computer</em>, <em>42</em>(8), 30-37. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:1" role="doc-endnote">
      <p>Wigner, E. P. (1958). On the distribution of the roots of certain symmetric matrices. <em>Annals of Mathematics</em>, <em>67</em>(2), 325-327. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="machine learning" /><category term="eigenvalue models" /><category term="adversarial robustness" /><category term="interpretability" /><category term="spectral norm" /><category term="operator norm" /><category term="regularization" /><summary type="html"><![CDATA[Robustness, interpretability, and scaling of eigenvalue models: stability bounds from Weyl's inequality, operator-norm feature importance, and regularization experiments for tabular data.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/pow_spec_props_norms_reg_15.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/pow_spec_props_norms_reg_15.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Behold the power of the spectrum!</title><link href="https://alexshtf.github.io/2025/12/16/Spectrum.html" rel="alternate" type="text/html" title="Behold the power of the spectrum!" /><published>2025-12-16T00:00:00+00:00</published><updated>2025-12-16T00:00:00+00:00</updated><id>https://alexshtf.github.io/2025/12/16/Spectrum</id><content type="html" xml:base="https://alexshtf.github.io/2025/12/16/Spectrum.html"><![CDATA[<p align="center">
  <a href="https://colab.research.google.com/github/alexshtf/alexshtf.github.io/blob/master/assets/spectrum_power_intro.ipynb" target="_blank" rel="noopener">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" />
  </a>
  <a href="https://huggingface.co/spaces/alexshtf/spectral_neuron_playground" target="_blank" rel="noopener">
    <picture>
      <source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-md-dark.svg" />
      <img alt="Open in HF Spaces" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-md.svg" />
    </picture>
  </a>  
</p>

<h1 id="intro">Intro</h1>

<p>When trying to model a complicated relationship between features, our go-to architectures are typically either neural networks or decision trees. They are well-established, well-studied, and have an abundance of software for training them. So why not?</p>

<p>But sometimes we have some additional requirements. Maybe we want our model to be an increasing / decreasing / convex / concave function of one or more features. Perhaps we want to model diminishing returns, and we want it to be both increasing and concave. Or maybe we care about the sensitivity of our model to noise, and want to certify its Lipschitz constant: what happens to the prediction if we slightly change the input?</p>

<p>In the recent <a href="https://www.youtube.com/watch?v=aR20FWCCjAs">interview</a> with Ilya Sutskever there was one interesting insight - perhaps we need neurons to do more compute than they do now. It immediately rang a bell and took me back to my Ph.D. days - what kind of “more compute” would be useful? Well, maybe a neuron can solve a small optimization problem! So in this post we shall explore an idea that most optimization researchers are familiar with - the eigenvalues of a matrix are optimal values of optimization problems. So what can one such neuron do? It turns out, quite a lot!</p>

<p>Why eigenvalues? Well, because of a unique combination of three properties. They can model fairly complicated functions, we have a lot of theoretical machinery to reason about them, and we can stand on the shoulders of giants and reuse the vast talent and resources invested over decades in their reliable computation.</p>

<p>The notebook for reproducing all the results is <a href="https://github.com/alexshtf/alexshtf.github.io/blob/master/assets/spectrum_power_intro.ipynb">here</a>. Feel free to deploy it to Colab and play with it.</p>

<h1 id="univariate-functions">Univariate functions</h1>

<p>Let’s begin our adventures with a simple case - a function of one variable. Suppose we’re given <em>symmetric</em> matrices, \(\mathbf{A}\) and \(\mathbf{B}\), and we define the function:</p>

\[f(x)=\lambda_k(\mathbf{A} + x \mathbf{B}),\]

<p>where \(\lambda_k(\cdot)\) is the \(k\)-th smallest eigenvalue of the given matrix. In general, we will be interested in two things - what kind of functions we can represent, and whether we can learn such functions from data.</p>

<p>Our exploration begins with plotting, so let’s implement such a function in a vectorized manner - it will take a vector of inputs, and produce a corresponding vector of outputs. Turns out SciPy has excellent tools just for that:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">scipy.linalg</span> <span class="k">as</span> <span class="n">sla</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="k">def</span> <span class="nf">univariate_spectral</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">xs</span><span class="p">):</span>
  <span class="s">"""Computes the vector y[i] = λₖ(A + B * xs[i])."""</span>
  
  <span class="c1"># support negative eigenvalue indices,
</span>  <span class="c1">#  e.g., k=-1 is the largest eigenvalue
</span>  <span class="n">k</span> <span class="o">=</span> <span class="n">k</span> <span class="o">%</span> <span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  
  <span class="c1"># create a batch of matrices, one for each entry in xs
</span>  <span class="n">mats</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="n">B</span> <span class="o">*</span> <span class="n">xs</span><span class="p">[...,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span>
  
  <span class="c1"># compute the k-th eigenvalue of each matrix
</span>  <span class="k">return</span> <span class="n">sla</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mats</span><span class="p">,</span> <span class="n">subset_by_index</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">)).</span><span class="n">squeeze</span><span class="p">()</span>
</code></pre></div></div>
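<p>If your SciPy version does not accept a stack of matrices in <code class="language-plaintext highlighter-rouge">eigvalsh</code>, NumPy’s <code class="language-plaintext highlighter-rouge">np.linalg.eigvalsh</code> is one possible fallback: it broadcasts over leading dimensions and returns all eigenvalues in ascending order, so indexing the last axis selects \(\lambda_k\), and negative indices work out of the box. The name <code class="language-plaintext highlighter-rouge">univariate_spectral_np</code> is hypothetical; this is a sketch, not the notebook’s code.</p>

```python
import numpy as np

def univariate_spectral_np(A, B, k, xs):
    """NumPy-only variant of y[i] = λₖ(A + B * xs[i]); k is 0-based, negative k allowed."""
    xs = np.asarray(xs, dtype=float)
    mats = A + B * xs[..., np.newaxis, np.newaxis]
    # ascending eigenvalues of every matrix in the stack (only the lower triangle is read)
    return np.linalg.eigvalsh(mats)[..., k]

rng = np.random.default_rng(42)
A = rng.standard_normal((3, 3))
B = rng.standard_normal((3, 3))
xs = np.linspace(-5, 5, 7)
ys = univariate_spectral_np(A, B, -1, xs)   # largest eigenvalue at each x
```

<p>Computing all eigenvalues is wasteful when only one is needed, which is what <code class="language-plaintext highlighter-rouge">subset_by_index</code> avoids, but for small matrices the difference is unlikely to matter.</p>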

<p>Note - we don’t have to pass symmetric matrices. <code class="language-plaintext highlighter-rouge">eigvalsh</code> treats its input as symmetric/Hermitian and reads only one triangle (the lower triangle by default). So what do the functions look like? Let’s take two \(3 \times 3\) matrices \(\mathbf{A}\) and \(\mathbf{B}\) and plot the three functions corresponding to the three eigenvalues:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="k">def</span> <span class="nf">plot_eigenfunctions</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">):</span>
  <span class="s">"""Plots λₖ(A + B * x) on a grid layout"""</span>
  <span class="n">fig</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span> <span class="o">*</span> <span class="n">n_cols</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">n_rows</span><span class="p">),</span> <span class="n">layout</span><span class="o">=</span><span class="s">'constrained'</span><span class="p">)</span>
  <span class="n">plot_xs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
  <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axs</span><span class="p">.</span><span class="n">ravel</span><span class="p">()):</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_xs</span><span class="p">,</span> <span class="n">univariate_spectral</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">plot_xs</span><span class="p">))</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">'$</span><span class="se">\\</span><span class="s">lambda_</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s">(A + B * x)$'</span><span class="p">)</span>
  <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>  


<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">plot_eigenfunctions</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_uni_3.png" alt="pow_spec_uni_3" /></p>

<p>Interesting. The smallest eigenvalue is a concave function, the largest is convex, and the middle appears arbitrary. What happens if we take two \(9 \times 9\) matrices?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">plot_eigenfunctions</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_uni_9.png" alt="pow_spec_uni_9" /></p>

<p>Again, the smallest eigenvalue is concave, the largest is convex, and the eigenvalues in between are typically neither convex nor concave. Coincidence?</p>

<p>Some of you probably know it is not a coincidence at all. Indeed, an elementary result from linear algebra states that the smallest eigenvalue of a symmetric matrix can be written as a minimum of quadratic functions:</p>

\[\lambda_1(\mathbf{P}) = \min\{ \mathbf{x}^T \mathbf{P} \mathbf{x} : \| \mathbf{x} \|_2 = 1 \}.\]

<p>These functions may be quadratic in \(\mathbf{x}\), but they are <em>linear</em> in the matrix \(\mathbf{P}\). And a minimum of linear functions is concave. Similarly, the largest eigenvalue is:</p>

\[\lambda_n(\mathbf{P})=\max\{\mathbf{x}^T \mathbf{P} \mathbf{x} : \|\mathbf{x}\|_2=1\},\]

<p>and a maximum of linear functions is convex. So it means our “univariate neuron”, at least for \(\lambda_1\) and \(\lambda_n\), is indeed solving an optimization problem!</p>
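<p>A quick numerical sanity check of these two facts, assuming random symmetric \(\mathbf{A}, \mathbf{B}\) are representative: the Rayleigh quotient \(\mathbf{x}^T \mathbf{P} \mathbf{x}\) of any unit vector lies between \(\lambda_1\) and \(\lambda_n\), and along \(x \mapsto \mathbf{A} + x\mathbf{B}\) the smallest eigenvalue satisfies the midpoint inequality of a concave function, while the largest satisfies that of a convex one. This is a sketch, not a proof; the helper <code class="language-plaintext highlighter-rouge">lam</code> is hypothetical.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
G, H = rng.standard_normal((2, n, n))
A, B = (G + G.T) / 2, (H + H.T) / 2          # random symmetric matrices

def lam(x, k):
    """k-th smallest eigenvalue (0-based, negative allowed) of A + x * B."""
    return np.linalg.eigvalsh(A + x * B)[k]

# 1. Rayleigh quotients of unit vectors are bracketed by the extreme eigenvalues.
P = A + 0.7 * B
eigs = np.linalg.eigvalsh(P)
us = rng.standard_normal((1000, n))
us /= np.linalg.norm(us, axis=1, keepdims=True)
rayleigh = np.einsum('ij,jk,ik->i', us, P, us)
print(eigs[0], rayleigh.min(), rayleigh.max(), eigs[-1])   # should appear in ascending order

# 2. Midpoint inequalities along random pairs of points.
pairs = rng.uniform(-5, 5, size=(1000, 2))
concave_gap = [lam((a + b) / 2, 0) - (lam(a, 0) + lam(b, 0)) / 2 for a, b in pairs]
convex_gap = [(lam(a, -1) + lam(b, -1)) / 2 - lam((a + b) / 2, -1) for a, b in pairs]
print(min(concave_gap), min(convex_gap))   # both should be non-negative up to round-off
```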

<p>But what about the other eigenvalues? There’s a more general version of the above, somewhat less known in the ML community - the Courant-Fischer theorem, which characterizes the \(k\)-th smallest eigenvalue:</p>

\[\lambda_k(\mathbf{P}) = \max_{\mathbf{V} \in \mathbb{R}^{(k-1)\times n}} \min_{\mathbf{x}} \left\{ \mathbf{x}^T \mathbf{P}\mathbf{x} : \| \mathbf{x}\|_2=1,\mathbf{V}\mathbf{x}=\mathbf{0} \right\}.\]

<p>This one appears a bit hairy, so let’s dissect it. It can be thought of as a game between two players: one chooses \(\mathbf{V}\), and the other chooses a unit vector \(\mathbf{x}\) that is <em>orthogonal</em> to the rows of \(\mathbf{V}\). The “outer” player aims to make \(\mathbf{x}^T \mathbf{P} \mathbf{x}\) as large as possible, and the “inner” one aims to make it as small as possible.</p>

<p>Consequently, \(\lambda_k\) in general is also the optimal value of an optimization problem. And the farther \(k\) is from the extremes \(k=1\) and \(k=n\), the “crazier” the function it tends to model. If we go back to the plots, we can even see that the functions are not necessarily smooth: they have “kinks”.</p>
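<p>The Courant-Fischer game can also be played numerically. In the sketch below (NumPy, with a hypothetical helper <code class="language-plaintext highlighter-rouge">inner_min</code>), the inner player’s problem for a fixed \(\mathbf{V}\) is solved exactly: if the columns of \(\mathbf{N}\) form an orthonormal basis of the null space of \(\mathbf{V}\), the smallest Rayleigh quotient over that null space is the smallest eigenvalue of \(\mathbf{N}^T \mathbf{P} \mathbf{N}\). Random moves of the outer player never beat \(\lambda_k\), while choosing the rows of \(\mathbf{V}\) to be the eigenvectors of the \(k-1\) smallest eigenvalues attains it.</p>

```python
import numpy as np

def inner_min(P, V):
    """min { x^T P x : ||x|| = 1, V x = 0 } for a symmetric matrix P."""
    if V.shape[0] == 0:
        N = np.eye(P.shape[0])
    else:
        _, s, Vt = np.linalg.svd(V)             # full SVD: Vt is n x n
        rank = int(np.sum(s > 1e-10))
        N = Vt[rank:].T                         # orthonormal basis of the null space of V
    return np.linalg.eigvalsh(N.T @ P @ N)[0]

rng = np.random.default_rng(1)
n, k = 7, 3                                     # k is 1-based, as in the formula
G = rng.standard_normal((n, n))
P = (G + G.T) / 2
eigvals, eigvecs = np.linalg.eigh(P)

# The outer player's best move: rows of V are eigenvectors of the k-1 smallest eigenvalues.
V_best = eigvecs[:, :k - 1].T
best_value = inner_min(P, V_best)

# Random moves of the outer player do no better.
random_values = [inner_min(P, rng.standard_normal((k - 1, n))) for _ in range(500)]
print(best_value, eigvals[k - 1], max(random_values))
```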

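<p>To get a more complete picture, let’s see what happens if we make \(\mathbf{B}\) not just a symmetric matrix, but a positive-semidefinite one, meaning all its eigenvalues are nonnegative. It’s easy to make one - just create a random matrix and zero out all its negative eigenvalues (as with <code class="language-plaintext highlighter-rouge">eigvalsh</code>, <code class="language-plaintext highlighter-rouge">eigh</code> reads only the lower triangle, so the input is treated as symmetric):</p>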

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_psd</span><span class="p">(</span><span class="n">B</span><span class="p">):</span>
  <span class="n">eigvals</span><span class="p">,</span> <span class="n">eigvecs</span> <span class="o">=</span> <span class="n">sla</span><span class="p">.</span><span class="n">eigh</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
  <span class="n">eigvals_pos</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">eigvals</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">eigvecs</span> <span class="o">@</span> <span class="n">np</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">eigvals_pos</span><span class="p">)</span> <span class="o">@</span> <span class="n">eigvecs</span><span class="p">.</span><span class="n">T</span>
</code></pre></div></div>

<p>Now let’s plot!</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">plot_eigenfunctions</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">make_psd</span><span class="p">(</span><span class="n">B</span><span class="p">),</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_uni_inc_9.png" alt="pow_spec_uni_inc_9" /></p>

<p>That’s interesting - all functions are increasing! What if we take a negative-semidefinite \(\mathbf{B}\)?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">plot_eigenfunctions</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="o">-</span><span class="n">make_psd</span><span class="p">(</span><span class="n">B</span><span class="p">),</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_uni_dec_9.png" alt="pow_spec_uni_dec_9" /></p>

<p>All are decreasing!</p>

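<p>Again, this is not a coincidence. Eigenvalue functions are <em>monotone</em>: adding a positive-semidefinite matrix to a symmetric matrix does not decrease any of its eigenvalues (a consequence of Weyl’s inequality). For \(x_2 > x_1\) we can write \(\mathbf{A} + x_2 \mathbf{B} = (\mathbf{A} + x_1 \mathbf{B}) + (x_2 - x_1)\mathbf{B}\), and when all eigenvalues of \(\mathbf{B}\) are nonnegative, the last term is positive semidefinite. So a larger \(x\) yields eigenvalues that are at least as large, and each \(\lambda_k(\mathbf{A} + x\mathbf{B})\) is a nondecreasing function of \(x\). The opposite happens when all eigenvalues of \(\mathbf{B}\) are nonpositive.</p>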
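<p>A numerical check of this monotonicity, as a NumPy sketch: <code class="language-plaintext highlighter-rouge">make_psd</code> is re-implemented here with <code class="language-plaintext highlighter-rouge">np.linalg.eigh</code> so the snippet runs on its own. We evaluate all eigenvalues on a grid of \(x\) values and inspect the differences between consecutive grid points.</p>

```python
import numpy as np

def make_psd(B):
    """Zero out the negative eigenvalues (eigh reads the lower triangle of B)."""
    eigvals, eigvecs = np.linalg.eigh(B)
    return (eigvecs * np.maximum(0, eigvals)) @ eigvecs.T

rng = np.random.default_rng(42)
A = rng.standard_normal((9, 9))
B_psd = make_psd(rng.standard_normal((9, 9)))

xs = np.linspace(-5, 5, 200)
# rows: x values, columns: eigenvalues in ascending order
spectra = np.linalg.eigvalsh(A + B_psd * xs[:, None, None])
steps = np.diff(spectra, axis=0)     # change of each λ_k between consecutive x values
print(steps.min())                   # non-negative up to round-off: every λ_k is nondecreasing

spectra_neg = np.linalg.eigvalsh(A - B_psd * xs[:, None, None])
print(np.diff(spectra_neg, axis=0).max())   # non-positive up to round-off
```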

<h2 id="beyond-univariate-functions">Beyond univariate functions</h2>

<p>To understand what happens beyond the univariate case, let’s look at functions of <em>two</em> variables given three symmetric matrices \(\mathbf{A}, \mathbf{B}, \mathbf{C}\):</p>

\[f(x, y) = \lambda_k(\mathbf{A} + x\mathbf{B} + y \mathbf{C})\]

<p>For plotting to be convenient, let’s first compute such \(f\) in a vectorized manner:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">bivariate_spectral</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
  <span class="s">"""Computes the vector z[i] = λₖ(A + B * xs[i] + C * ys[i])."""</span>
  
  <span class="c1"># support negative eigenvalue indices,
</span>  <span class="c1">#  e.g., k=-1 is the largest eigenvalue
</span>  <span class="n">k</span> <span class="o">=</span> <span class="n">k</span> <span class="o">%</span> <span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  
  <span class="c1"># create a batch of matrices, one for each point (xs[i], ys[i])
</span>  <span class="n">mats</span> <span class="o">=</span> <span class="p">(</span>
      <span class="n">A</span> <span class="o">+</span> <span class="n">B</span> <span class="o">*</span> <span class="n">xs</span><span class="p">[...,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span> 
        <span class="o">+</span> <span class="n">C</span> <span class="o">*</span> <span class="n">ys</span><span class="p">[...,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span>
  <span class="p">)</span>
  
  <span class="c1"># compute the k-th eigenvalue of each matrix
</span>  <span class="k">return</span> <span class="n">sla</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mats</span><span class="p">,</span> <span class="n">subset_by_index</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">)).</span><span class="n">squeeze</span><span class="p">()</span>
</code></pre></div></div>
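<p>One way to connect the bivariate case to the univariate one: along any line \((x, y) = (x_0, y_0) + t(d_x, d_y)\) we have \(\mathbf{A} + x\mathbf{B} + y\mathbf{C} = (\mathbf{A} + x_0\mathbf{B} + y_0\mathbf{C}) + t(d_x\mathbf{B} + d_y\mathbf{C})\), which is again a univariate eigenvalue function of \(t\). So, for example, \(\lambda_1\) is concave along every line, and hence jointly concave in \((x, y)\). The NumPy sketch below checks that identity numerically; <code class="language-plaintext highlighter-rouge">bivariate_np</code> is a hypothetical stand-in that uses <code class="language-plaintext highlighter-rouge">np.linalg.eigvalsh</code> instead of SciPy.</p>

```python
import numpy as np

def bivariate_np(A, B, C, k, xs, ys):
    """λ_k(A + xs[i] * B + ys[i] * C) for each i (0-based k, lower triangle is read)."""
    xs = np.asarray(xs, dtype=float)[..., None, None]
    ys = np.asarray(ys, dtype=float)[..., None, None]
    return np.linalg.eigvalsh(A + xs * B + ys * C)[..., k]

rng = np.random.default_rng(0)
A, B, C = rng.standard_normal((3, 5, 5))
x0, y0, dx, dy = 0.5, -1.0, 2.0, 3.0

# Restricting to the line (x0, y0) + t * (dx, dy) yields a univariate eigenvalue function
A_line = A + x0 * B + y0 * C
B_line = dx * B + dy * C
ts = np.linspace(-2, 2, 50)
on_plane = bivariate_np(A, B, C, 0, x0 + ts * dx, y0 + ts * dy)
on_line = np.linalg.eigvalsh(A_line + ts[:, None, None] * B_line)[:, 0]
print(np.abs(on_plane - on_line).max())   # round-off only

# λ_1 along the line is concave: its second differences should be non-positive
print(np.diff(on_line, n=2).max())
```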

<p>Now we can create surface plots of all eigenvalue functions given the three matrices. Here is the plotting function - nothing special, just a grid of 3D surface plots:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">plot_eigenfunctions_2d</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">):</span>
  <span class="s">"""Plots z = λₖ(A + B * x + C * y) on a grid layout"""</span>
  <span class="n">fig</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span>
      <span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">n_cols</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">n_rows</span><span class="p">),</span>
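      <span class="n">subplot_kw</span><span class="o">=</span><span class="p">{</span><span class="s">"projection"</span><span class="p">:</span> <span class="s">"3d"</span><span class="p">},</span> <span class="n">layout</span><span class="o">=</span><span class="s">'constrained'</span><span class="p">,</span>
      <span class="n">squeeze</span><span class="o">=</span><span class="bp">False</span>  <span class="c1"># always a 2D array of axes, even for a 1x1 grid</span>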
  <span class="p">)</span>
  
  <span class="n">plot_xs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">50</span><span class="p">)</span>
  <span class="n">grid_xs</span><span class="p">,</span> <span class="n">grid_ys</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">plot_xs</span><span class="p">,</span> <span class="n">plot_xs</span><span class="p">)</span>

  <span class="n">n</span> <span class="o">=</span> <span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>  <span class="c1"># matrix size = number of eigenvalue functions</span>
  <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axs</span><span class="p">.</span><span class="n">ravel</span><span class="p">()[:</span><span class="n">n</span><span class="p">]):</span>
    <span class="n">grid_zs</span> <span class="o">=</span> <span class="n">bivariate_spectral</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">grid_xs</span><span class="p">,</span> <span class="n">grid_ys</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot_surface</span><span class="p">(</span><span class="n">grid_xs</span><span class="p">,</span> <span class="n">grid_ys</span><span class="p">,</span> <span class="n">grid_zs</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'viridis'</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">'$</span><span class="se">\\</span><span class="s">lambda_</span><span class="si">{</span><span class="mi">1</span> <span class="o">+</span> <span class="n">k</span><span class="si">}</span><span class="s">(A + B * x + C * y)$'</span><span class="p">)</span>
  <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>  
</code></pre></div></div>

<p>We’re all set. Let’s plot three eigenvalue functions corresponding to random \(3 \times 3\) matrices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">plot_eigenfunctions_2d</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_bi_3.png" alt="pow_spec_bi_3" /></p>

<p>As expected, the smallest eigenvalue produces a concave function and the largest produces a convex one, while the middle eigenvalue has no such guarantee and can take an arbitrary shape. What if we increase the matrix size to \(9 \times 9\)?</p>
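<p>Why do the extremes behave this way? The largest eigenvalue is \(\lambda_{\max}(\mathbf{M}) = \max_{\|\boldsymbol v\| = 1} \boldsymbol v^T \mathbf{M} \boldsymbol v\), a maximum of functions that are linear in \(\mathbf{M}\). Since \(\mathbf{M} = \mathbf{A} + x \mathbf{B} + y \mathbf{C}\) is affine in \((x, y)\), the largest eigenvalue is convex in \((x, y)\), and by the symmetric argument the smallest one is concave. Below is a quick numerical sketch of the midpoint inequality. It assumes symmetrized random matrices of our own choosing, and the helper names <code class="language-plaintext highlighter-rouge">sym</code>, <code class="language-plaintext highlighter-rouge">eigval</code> and <code class="language-plaintext highlighter-rouge">midpoint_gaps</code> are hypothetical.</p>

```python
import numpy as np

def sym(M):
    # keep only the symmetric part of M
    return (M + M.T) / 2

def eigval(A, B, C, k, x, y):
    # k-th smallest eigenvalue of A + x B + y C (k = -1 is the largest)
    return np.linalg.eigvalsh(A + x * B + y * C)[k]

def midpoint_gaps(A, B, C, k, n_pairs=500, seed=0):
    """lambda_k at the midpoint of random pairs of points, minus the average of
    lambda_k at the two endpoints. A convex function gives non-positive gaps,
    a concave function gives non-negative gaps."""
    rng = np.random.default_rng(seed)
    gaps = []
    for _ in range(n_pairs):
        p, q = rng.uniform(-5, 5, size=(2, 2))  # two random points in [-5, 5]^2
        mid = (p + q) / 2
        gaps.append(
            eigval(A, B, C, k, *mid)
            - (eigval(A, B, C, k, *p) + eigval(A, B, C, k, *q)) / 2
        )
    return np.array(gaps)

rng = np.random.default_rng(42)
A, B, C = (sym(rng.standard_normal((3, 3))) for _ in range(3))
print(midpoint_gaps(A, B, C, 0).min())   # smallest eigenvalue: non-negative up to round-off
print(midpoint_gaps(A, B, C, -1).max())  # largest eigenvalue: non-positive up to round-off
```

<p>Running the same check with the middle index <code class="language-plaintext highlighter-rouge">k=1</code> typically produces gaps of both signs, which matches the “arbitrary shape” in the middle plot.</p>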

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">plot_eigenfunctions_2d</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_bi_9.png" alt="pow_spec_bi_9" /></p>

<p>The picture is similar, but the functions can be more “complicated”. The middle eigenvalue \(\lambda_5\) has the “craziest” shape, whereas the smallest eigenvalue still produces a concave function and the largest a convex one.</p>

<p>I assume you can guess what happens when \(\mathbf{B}\) is negative semi-definite and \(\mathbf{C}\) is positive semi-definite:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="n">plot_eigenfunctions_2d</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="o">-</span><span class="n">make_psd</span><span class="p">(</span><span class="n">B</span><span class="p">),</span> <span class="n">make_psd</span><span class="p">(</span><span class="n">C</span><span class="p">),</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_bi_9_nsd_psd.png" alt="pow_spec_bi_9_nsd_psd" /></p>

<p>As expected, all functions are <em>nonincreasing</em> in \(x\) and <em>nondecreasing</em> in \(y\). Moreover, the smallest eigenvalue yields a concave function, whereas the largest yields a convex one.</p>
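<p>One way to see why the monotonicity holds for <em>every</em> eigenvalue, not only the extreme ones, is the Courant–Fischer min-max characterization. With eigenvalues sorted in ascending order \(\lambda_1 \leq \dots \leq \lambda_n\), the \(k\)-th eigenvalue of a symmetric matrix \(\mathbf{M}\) is</p>

\[\lambda_k(\mathbf{M}) = \min_{S:\ \dim S = k} \ \max_{\boldsymbol v \in S,\ \|\boldsymbol v\| = 1} \boldsymbol v^T \mathbf{M} \boldsymbol v.\]

<p>Let \(\mathbf{M}(x, y) = \mathbf{A} + x \mathbf{B} + y \mathbf{C}\), where \(\mathbf{C}\) is positive semi-definite. For any \(y' \geq y\) and any vector \(\boldsymbol v\),</p>

\[\boldsymbol v^T \mathbf{M}(x, y') \boldsymbol v = \boldsymbol v^T \mathbf{M}(x, y) \boldsymbol v + (y' - y)\, \boldsymbol v^T \mathbf{C} \boldsymbol v \geq \boldsymbol v^T \mathbf{M}(x, y) \boldsymbol v.\]

<p>Taking the maximum over \(\boldsymbol v \in S\) and then the minimum over \(S\) preserves the inequality, so \(\lambda_k(\mathbf{M}(x, y')) \geq \lambda_k(\mathbf{M}(x, y))\) for every \(k\). The same argument applied to the negative semi-definite \(\mathbf{B}\) shows that every \(\lambda_k\) is nonincreasing in \(x\).</p>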

<p>Now why would it be interesting for ML people? Because we can <em>learn</em> the matrices from data! In fact, we can do even more than that - we can <em>predict</em> them. For example, consider a model for predicting the probability of winning a government contract given the contract document and a bid. We use an encoder to transform the document into a symmetric matrix \(\mathbf{A}\) and a positive semi-definite matrix \(\mathbf{B}\), and just compute \(\lambda_k(\mathbf{A} + \mathrm{bid} \cdot \mathbf{B})\), as illustrated below:</p>

<p><img src="https://alexshtf.github.io/assets/pow_spec_auction_illustration.png" alt="pow_spec_auction_illustration" /></p>

<p>Now we can pass the score to a sigmoid function to obtain a probability that, by design, is nondecreasing in the bid!</p>

<p>You want a more complex model of <em>two</em> numerical inputs \(x\) and \(y\), but it must increase in \(x\) and decrease in \(y\)? No problem! Just encode your other features into an arbitrary \(\mathbf{A}\), a positive semi-definite \(\mathbf{B}\), and a negative semi-definite \(\mathbf{C}\), and compute \(\lambda_k(\mathbf{A} + x \mathbf{B} + y \mathbf{C})\).</p>
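<p>Here is a minimal PyTorch sketch of such a model. It rests on a few assumptions of our own. First, all features other than \(x\) and \(y\) are already encoded into a vector <code class="language-plaintext highlighter-rouge">h</code>. Second, three linear heads stand in for a real encoder and map <code class="language-plaintext highlighter-rouge">h</code> to \(\mathbf{A}\) and to two factors \(\mathbf{L}_B, \mathbf{L}_C\). We then set \(\mathbf{B} = \mathbf{L}_B \mathbf{L}_B^T\) (positive semi-definite) and \(\mathbf{C} = -\mathbf{L}_C \mathbf{L}_C^T\) (negative semi-definite). The class name <code class="language-plaintext highlighter-rouge">MonotoneSpectral</code> is hypothetical.</p>

```python
import torch
from torch import nn

class MonotoneSpectral(nn.Module):
    """Hypothetical sketch: score = lambda_k(A(h) + x * B(h) + y * C(h)),
    nondecreasing in x (B is PSD) and nonincreasing in y (C is NSD)."""

    def __init__(self, feat_dim: int, dim: int, eigval_idx: int = 0):
        super().__init__()
        self.dim, self.eigval_idx = dim, eigval_idx
        # linear heads standing in for an arbitrary encoder of the other features h
        self.head_A = nn.Linear(feat_dim, dim * dim)
        self.head_LB = nn.Linear(feat_dim, dim * dim)
        self.head_LC = nn.Linear(feat_dim, dim * dim)

    def _square(self, head, h):
        # map a batch of feature vectors to a batch of dim x dim matrices
        return head(h).reshape(*h.shape[:-1], self.dim, self.dim)

    def forward(self, h, x, y):
        A = self._square(self.head_A, h)
        A = (A + A.mT) / 2                # symmetric, otherwise unconstrained
        LB = self._square(self.head_LB, h)
        LC = self._square(self.head_LC, h)
        B = LB @ LB.mT                    # PSD: v^T B v = |LB^T v|^2 >= 0
        C = -(LC @ LC.mT)                 # NSD: the negation of a PSD matrix
        mats = A + x[..., None, None] * B + y[..., None, None] * C
        return torch.linalg.eigvalsh(mats)[..., self.eigval_idx]
```

<p>For a probability, such as the contract-winning example, pass the score through <code class="language-plaintext highlighter-rouge">torch.sigmoid(model(h, x, y))</code>. The sigmoid is increasing, so it preserves both monotonicity directions. Parametrizing through factors \(\mathbf{L}\) is one common way to keep the constraints satisfied during unconstrained gradient training. It is not the only option.</p>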

<h1 id="training">Training</h1>

<p>But can we actually back-propagate through \(\lambda_k\) to learn our encoder? It turns out we can, and everything we need is already in PyTorch. In this post, for the sake of demonstration, we shall do something extremely simple: use a model whose parameters are just \(\mathbf{A}\), \(\mathbf{B}\) and \(\mathbf{C}\), and learn them from noisy samples of a concave function of two variables.</p>
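<p>Why back-propagation works: where \(\lambda_k\) is a simple (non-repeated) eigenvalue, it is differentiable. Its derivative in a symmetric direction \(\mathbf{B}\) is \(\frac{d}{dt}\lambda_k(\mathbf{A} + t\mathbf{B})\big|_{t=0} = \boldsymbol v_k^T \mathbf{B} \boldsymbol v_k\), where \(\boldsymbol v_k\) is the corresponding unit eigenvector (a standard first-order perturbation result). At repeated eigenvalues it is generally not differentiable. Below is a small sketch that compares autograd through <code class="language-plaintext highlighter-rouge">torch.linalg.eigvalsh</code> with this closed form, on random symmetric matrices made up for the check.</p>

```python
import torch

torch.manual_seed(0)
dim, k = 4, 0
sym = lambda M: (M + M.mT) / 2
A = sym(torch.randn(dim, dim, dtype=torch.float64))
B = sym(torch.randn(dim, dim, dtype=torch.float64))

# autograd: derivative of lambda_k(A + t B) with respect to the scalar t, at t = 0
t = torch.zeros((), dtype=torch.float64, requires_grad=True)
lam_k = torch.linalg.eigvalsh(A + t * B)[k]
lam_k.backward()

# closed form v_k^T B v_k, with v_k the unit eigenvector of A for lambda_k
eigvals, eigvecs = torch.linalg.eigh(A)
v = eigvecs[:, k]
closed_form = v @ B @ v
print(t.grad.item(), closed_form.item())  # the two numbers should agree
```

<p>The same formula also explains the monotonicity above: if \(\mathbf{B}\) is positive semi-definite, then \(\boldsymbol v_k^T \mathbf{B} \boldsymbol v_k \geq 0\).</p>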

<p>So here is our simple concave function:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="k">return</span> <span class="o">-</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">y</span><span class="o">+</span><span class="mf">0.5</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="o">-</span><span class="n">y</span><span class="o">+</span><span class="mf">0.5</span><span class="p">))</span>
</code></pre></div></div>
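<p>Why is \(f\) concave? The log-sum-exp function \(\log(e^{a} + e^{b} + e^{c})\) is convex, and composing it with affine functions of \((x, y)\) keeps it convex, so its negation \(f\) is concave. On the plotting range \([-3, 3]^2\) the direct formula is fine, but <code class="language-plaintext highlighter-rouge">np.exp</code> overflows for large inputs. Below is a minimal sketch of an equivalent, numerically stable variant that uses <code class="language-plaintext highlighter-rouge">scipy.special.logsumexp</code>. The name <code class="language-plaintext highlighter-rouge">f_stable</code> is ours.</p>

```python
import numpy as np
from scipy.special import logsumexp

def f(x, y):
    # the original definition from above
    return -np.log(np.exp(x-1) + np.exp(y+0.5) + np.exp(-x-y+0.5))

def f_stable(x, y):
    # same function: stack the three affine terms on a new last axis and let
    # logsumexp subtract their maximum before exponentiating
    terms = np.stack([x - 1, y + 0.5, -x - y + 0.5], axis=-1)
    return -logsumexp(terms, axis=-1)

X, Y = np.meshgrid(np.linspace(-3, 3, 50), np.linspace(-3, 3, 50))
print(np.allclose(f(X, Y), f_stable(X, Y)))  # True on the plotting range
print(f_stable(1000.0, 0.0))  # about -999, where f itself would overflow to -inf
```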

<p>Let’s take a look at its graph and its level sets:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">plot_bivariate</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">show_rmse</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
    <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
    <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">projection</span><span class="o">=</span><span class="s">'3d'</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot_surface</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">)</span>
    <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">contour</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">levels</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">show_rmse</span><span class="p">:</span>
        <span class="n">rmse</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">Z</span> <span class="o">-</span> <span class="n">f</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">))))</span>
        <span class="n">fig</span><span class="p">.</span><span class="n">suptitle</span><span class="p">(</span><span class="sa">f</span><span class="s">'RMSE = </span><span class="si">{</span><span class="n">rmse</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">300</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">300</span><span class="p">)</span>
<span class="n">X</span><span class="p">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">plot_bivariate</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">f</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">))</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_concave_bi.png" alt="pow_spec_concave_bi" /></p>

<p>It indeed appears to be a concave function with a non-trivial shape. We will use the <code class="language-plaintext highlighter-rouge">show_rmse</code> flag later, when we plot a fitted model and want to show its approximation error, measured on the grid against the true function <code class="language-plaintext highlighter-rouge">f</code>.</p>

<p>To train a model we shall need two ingredients: training data and a PyTorch module. So let’s generate some training data:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">TensorDataset</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">as_tensor</span>

<span class="k">def</span> <span class="nf">make_trainset</span><span class="p">(</span><span class="n">n_train</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">200</span><span class="p">,</span> <span class="n">train_noise</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">):</span>
    <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">n_train</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">n_train</span><span class="p">)</span>
    <span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="o">+</span> <span class="n">train_noise</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n_train</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span>

<span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">z_train</span> <span class="o">=</span> <span class="n">make_trainset</span><span class="p">()</span>
<span class="n">ds</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span>
    <span class="n">as_tensor</span><span class="p">(</span><span class="n">x_train</span><span class="p">),</span> <span class="n">as_tensor</span><span class="p">(</span><span class="n">y_train</span><span class="p">),</span> <span class="n">as_tensor</span><span class="p">(</span><span class="n">z_train</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div></div>

<p>Here is what it looks like:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">projection</span><span class="o">=</span><span class="s">'3d'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">z_train</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'red'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot_surface</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">f</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">),</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_concave_bi_train.png" alt="pow_spec_concave_bi_train" /></p>

<p>Assume that we know we are learning some real-world phenomenon that must be concave. Then our PyTorch model should use the smallest eigenvalue function. Below is a PyTorch re-implementation of the above. It is slightly more generic than we need right now, since it supports any eigenvalue index, not just the smallest. Just like SciPy, PyTorch reads only one triangle of each matrix (the lower triangle by default), so we don’t need to symmetrize the matrices explicitly.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.linalg</span> <span class="k">as</span> <span class="n">tla</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">BivariateSpectral</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eigval_idx</span> <span class="o">=</span> <span class="n">eigval_idx</span> <span class="o">%</span> <span class="n">dim</span>  <span class="c1"># modulo - to support negative idx
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">A</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">B</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">C</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
        <span class="c1"># create a batch of matrices, one for each point (x[i], y[i])
</span>        <span class="n">mats</span> <span class="o">=</span> <span class="p">(</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">A</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">B</span> <span class="o">*</span> <span class="n">x</span><span class="p">[...,</span> <span class="bp">None</span><span class="p">,</span> <span class="bp">None</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">C</span> <span class="o">*</span> <span class="n">y</span><span class="p">[...,</span> <span class="bp">None</span><span class="p">,</span> <span class="bp">None</span><span class="p">]</span>
        <span class="p">)</span>

        <span class="c1"># compute the eigenvalues
</span>        <span class="n">eigvals</span> <span class="o">=</span> <span class="n">tla</span><span class="p">.</span><span class="n">eigvalsh</span><span class="p">(</span><span class="n">mats</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">eigvals</span><span class="p">[...,</span> <span class="bp">self</span><span class="p">.</span><span class="n">eigval_idx</span><span class="p">]</span>
</code></pre></div></div>
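<p>To see the triangle behavior described above in action, here is a small check on a deliberately non-symmetric random matrix. <code class="language-plaintext highlighter-rouge">torch.linalg.eigvalsh</code> takes a <code class="language-plaintext highlighter-rouge">UPLO</code> argument (default <code class="language-plaintext highlighter-rouge">'L'</code>) that selects which triangle is read. The result therefore matches the symmetric matrix obtained by mirroring that triangle. The matrix and variable names are our own.</p>

```python
import torch
import torch.linalg as tla

torch.manual_seed(0)
M = torch.randn(4, 4, dtype=torch.float64)  # deliberately non-symmetric

# symmetric matrices obtained by mirroring one triangle of M onto the other
from_lower = torch.tril(M) + torch.tril(M, diagonal=-1).mT
from_upper = torch.triu(M) + torch.triu(M, diagonal=1).mT

print(torch.allclose(tla.eigvalsh(M), tla.eigvalsh(from_lower)))            # True
print(torch.allclose(tla.eigvalsh(M, UPLO='U'), tla.eigvalsh(from_upper)))  # True
```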

<p>Alright! Let’s train our smallest eigenvalue model! Below is a pretty standard PyTorch training loop:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
<span class="kn">import</span> <span class="nn">math</span>

<span class="k">def</span> <span class="nf">train_model</span><span class="p">(</span>
        <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">n_epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">500</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span>
    <span class="p">):</span>
    <span class="n">print_every</span> <span class="o">=</span> <span class="n">n_epochs</span> <span class="o">//</span> <span class="mi">10</span>
    <span class="n">dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ds</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

    <span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span>
    <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">n_epochs</span><span class="p">):</span>
        <span class="n">epoch_loss</span> <span class="o">=</span> <span class="mf">0.</span>
        <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span> <span class="ow">in</span> <span class="n">dl</span><span class="p">:</span>
            <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">),</span> <span class="n">z</span><span class="p">)</span>
            <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>

            <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
                <span class="n">epoch_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span>
            <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>

        <span class="k">if</span> <span class="n">epoch</span> <span class="o">==</span> <span class="n">n_epochs</span> <span class="ow">or</span> <span class="n">epoch</span> <span class="o">%</span> <span class="n">print_every</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">train_rmse</span> <span class="o">=</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">epoch_loss</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">ds</span><span class="p">))</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s">, train RMSE: </span><span class="si">{</span><span class="n">train_rmse</span><span class="si">:</span> <span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">model</span>
</code></pre></div></div>

<p>Alright. Let’s try training a concave model with \(3 \times 3\) matrices, and plot it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span><span class="n">BivariateSpectral</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="n">Z</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">as_tensor</span><span class="p">(</span><span class="n">Y</span><span class="p">)).</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">plot_bivariate</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">show_rmse</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>

<p>The output:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Epoch 50, train RMSE:  0.0952
Epoch 100, train RMSE:  0.0910
Epoch 150, train RMSE:  0.0884
Epoch 200, train RMSE:  0.0866
Epoch 250, train RMSE:  0.0859
Epoch 300, train RMSE:  0.0862
Epoch 350, train RMSE:  0.0860
Epoch 400, train RMSE:  0.0858
Epoch 450, train RMSE:  0.0858
Epoch 500, train RMSE:  0.0856
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_concave_bi_model_3.png" alt="pow_spec_concave_bi_model_3" /></p>

<p>Not bad! Back-propagation appears to be working - the training loss went down, and the plot shows that the model learned something close to the truth.</p>

<p>Now let’s try something more interesting - matrices of size \(20 \times 20\). Note that even <em>one</em> such matrix has more entries than there are samples in the training set, so such a model is heavily over-parametrized. To make sure we converge to the smallest loss I did some tuning, and it turns out we need more epochs and a lower learning rate:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span>
    <span class="n">BivariateSpectral</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">2000</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-4</span>
<span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="n">Z</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">as_tensor</span><span class="p">(</span><span class="n">Y</span><span class="p">)).</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">plot_bivariate</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">show_rmse</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Epoch 200, train RMSE:  4.1648
Epoch 400, train RMSE:  2.0389
Epoch 600, train RMSE:  0.5819
Epoch 800, train RMSE:  0.0964
Epoch 1000, train RMSE:  0.0826
Epoch 1200, train RMSE:  0.0811
Epoch 1400, train RMSE:  0.0802
Epoch 1600, train RMSE:  0.0796
Epoch 1800, train RMSE:  0.0793
Epoch 2000, train RMSE:  0.0792
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_concave_bi_model_20.png" alt="pow_spec_concave_bi_model_20" /></p>

<p>Whoa! This is interesting! Even though the model appears over-parametrized, it does not memorize the training set! Why? Well, one possible explanation is that concavity is a strong inductive bias: with noisy samples, a concave function typically cannot interpolate all points, so the best achievable MSE stays positive. Optimization effects and implicit regularization may also play a role. Also, the model is not “crazy” - its shape still resembles the true one, even if it’s a bit different.</p>
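<p>The concavity argument relies on a classical fact: the smallest eigenvalue of a symmetric matrix is a concave function of the matrix (it is a minimum of Rayleigh quotients, which are linear in the matrix), and composing it with the affine map \((x, y) \mapsto A_0 + x A_1 + y A_2\) keeps it concave. Below is a small NumPy sanity check of that fact, plus a hand-built counter-example showing that a middle eigenvalue is neither concave nor convex. It is only a sketch: the random matrices, the sampling range, and the helper names (<code class="language-plaintext highlighter-rouge">random_symmetric</code>, <code class="language-plaintext highlighter-rouge">pencil_eigval</code>, <code class="language-plaintext highlighter-rouge">middle_eigval</code>) are my own illustrative choices, and I assume that <code class="language-plaintext highlighter-rouge">eigval_idx=0</code> in the concave model refers to the smallest eigenvalue.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def random_symmetric(n, rng):
    # symmetrize a random Gaussian matrix
    m = rng.standard_normal((n, n))
    return (m + m.T) / 2

def pencil_eigval(A0, A1, A2, x, y, k):
    # k-th smallest eigenvalue of A0 + x * A1 + y * A2 (eigvalsh returns them in ascending order)
    return np.linalg.eigvalsh(A0 + x * A1 + y * A2)[k]

# a random 20 x 20 bivariate pencil (arbitrary, for illustration only)
A0, A1, A2 = (random_symmetric(20, rng) for _ in range(3))

# midpoint concavity of the smallest eigenvalue (k = 0): f at the midpoint of two points
# should be at least the average of f at the two points, so every gap is non-negative
gaps = []
for _ in range(1000):
    p, q = rng.uniform(-3, 3, size=(2, 2))
    f_mid = pencil_eigval(A0, A1, A2, *((p + q) / 2), k=0)
    f_avg = (pencil_eigval(A0, A1, A2, *p, k=0) + pencil_eigval(A0, A1, A2, *q, k=0)) / 2
    gaps.append(f_mid - f_avg)
min_gap = min(gaps)  # non-negative up to floating-point round-off

# for contrast, the middle eigenvalue of diag(x, -x, 1) equals min(|x|, 1):
# convex near x = 0 and concave near |x| = 1, so neither property holds globally
B0 = np.diag([0.0, 0.0, 1.0])
B1 = np.diag([1.0, -1.0, 0.0])

def middle_eigval(x):
    return np.linalg.eigvalsh(B0 + x * B1)[1]

# positive values mean the corresponding inequality is violated
convexity_violation = middle_eigval(1.0) - (middle_eigval(0.0) + middle_eigval(2.0)) / 2   # 1 - 0.5
concavity_violation = (middle_eigval(-0.5) + middle_eigval(0.5)) / 2 - middle_eigval(0.0)  # 0.5 - 0
```

<p>If the concavity claim holds, <code class="language-plaintext highlighter-rouge">min_gap</code> should be non-negative up to round-off, while both violation measures for the middle eigenvalue come out as 0.5.</p>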

<p>So let’s put our conjecture to the test, and try to fit the <em>middle</em> eigenvalue. This function is not concave by design and, as we saw, can take many “crazy” shapes.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span>
    <span class="n">BivariateSpectral</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">10</span><span class="p">),</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">2000</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-4</span>
<span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="n">Z</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">as_tensor</span><span class="p">(</span><span class="n">Y</span><span class="p">)).</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">plot_bivariate</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">show_rmse</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Epoch 200, train RMSE:  0.1979
Epoch 400, train RMSE:  0.1031
Epoch 600, train RMSE:  0.0876
Epoch 800, train RMSE:  0.0822
Epoch 1000, train RMSE:  0.0798
Epoch 1200, train RMSE:  0.0786
Epoch 1400, train RMSE:  0.0778
Epoch 1600, train RMSE:  0.0768
Epoch 1800, train RMSE:  0.0764
Epoch 2000, train RMSE:  0.0760
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_concave_bi_model_20_mid.png" alt="pow_spec_concave_bi_model_20_mid" /></p>

<p>Training error is slightly lower, but apparently does <em>not</em> go down to zero! What happens if we increase the model’s representation power by increasing the size of the matrices? Let’s try the middle eigenvalue of \(40 \times 40\) matrices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span>
    <span class="n">BivariateSpectral</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">40</span><span class="p">,</span> <span class="n">eigval_idx</span><span class="o">=</span><span class="mi">20</span><span class="p">),</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">2000</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-4</span>
<span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="n">Z</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">as_tensor</span><span class="p">(</span><span class="n">Y</span><span class="p">)).</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">plot_bivariate</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">show_rmse</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Epoch 200, train RMSE:  0.1032
Epoch 400, train RMSE:  0.0710
Epoch 600, train RMSE:  0.0668
Epoch 800, train RMSE:  0.0650
Epoch 1000, train RMSE:  0.0639
Epoch 1200, train RMSE:  0.0624
Epoch 1400, train RMSE:  0.0617
Epoch 1600, train RMSE:  0.0609
Epoch 1800, train RMSE:  0.0605
Epoch 2000, train RMSE:  0.0598
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/pow_spec_concave_bi_model_40_mid.png" alt="pow_spec_concave_bi_model_40_mid" /></p>

<p>Training error indeed goes down, so our model gained expressive power. The model has 2460 parameters, far more than our 200 training samples. Yet the training error does not go to zero, the test error does not explode, and the plot still resembles the true function! So perhaps these functions have some other property “by design”, even if we don’t impose a shape by using extreme eigenvalues?</p>
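<p>The figure of 2460 parameters is consistent with the bivariate model learning three symmetric \(40 \times 40\) matrices \(A_0, A_1, A_2\), each stored through its \(40 \cdot 41 / 2 = 820\) upper-triangular entries, since \(3 \cdot 820 = 2460\). A short arithmetic check, assuming that parametrization (the helper name is hypothetical):</p>

```python
def spectral_param_count(dim, n_matrices=3):
    # assumption: each symmetric dim x dim matrix is stored through its upper triangle,
    # which has dim * (dim + 1) / 2 free entries
    per_matrix = dim * (dim + 1) // 2
    return n_matrices * per_matrix

# A_0, A_1, A_2 for the 3 x 3, 20 x 20 and 40 x 40 bivariate models in this post
counts = {dim: spectral_param_count(dim) for dim in (3, 20, 40)}
print(counts)  # {3: 18, 20: 630, 40: 2460}
```

<p>Under the same assumption, even the \(20 \times 20\) model has 630 parameters, already more than three times the number of training samples.</p>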

<p>Well, these eigenvalue functions apparently have some interesting properties worth investigating. So now is a good time to recap and summarise what we saw, because it is quite a lot to grasp.</p>

<h1 id="summary">Summary</h1>

<p>We saw here something that most linear algebra and optimization researchers have known for a long time - eigenvalues of symmetric matrices can be used to model functions of various desired shapes. If we know that the real-world phenomenon we aim to learn has some shape, it might be a good idea to model it that way <em>by design</em>. It makes our model sane at inference, because it behaves as we would expect, and in shape-constrained settings it can also improve training behavior (at least in this toy example).</p>

<p>Some bibliographic notes. Monotonicity in neural networks has been studied before. One prominent line of work is that of deep lattice networks<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>, and mixing monotonicity with convexity and concavity was studied by Runje and Shankaranarayana<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup>. And although the facts I showed here have been known for a long time, the idea of applying them to construct a learnable model from generic eigenvalue functions is, to the best of my knowledge, quite recent, and appeared in the 2025 paper by Cook et al.<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup>. Their work also proved these models to be universal approximators. Universal approximation is an asymptotic expressivity statement: it does not guarantee interpolation at a fixed size, nor does it say anything about optimization. So it does not contradict our observations here. Moreover, universality is not that important, especially when we care about a specific shape. In fact, in these cases we actually do <em>not</em> want a universal family - we want a family respecting our shape constraints.</p>

<p>Now with this toolbox in mind, we can try to explore other things:</p>

<ul>
  <li>How do we build a model that learns positive semidefinite matrices in PyTorch to model increasing / decreasing functions?</li>
  <li>Can we mix and match? For example - build a model that is <em>by design</em> increasing in \(x, z\), decreasing in \(y\), and jointly concave in \((x, y)\)?</li>
  <li>Do we need fully dense matrices? Maybe low-rank / sparse / banded matrices already have enough expressive power?</li>
  <li>What exactly are the properties of eigenvalue functions? What is their representation power? How do eigenvalue functions compare to just plain old ReLU MLPs?</li>
  <li>What are the theoretical interpretations of these eigenvalue functions? Perhaps we can think of them as deep neural networks? Or maybe something else?</li>
</ul>

<p>In the next posts we shall explore some of these aspects, and maybe more. So stay tuned!</p>

<h1 id="references">References</h1>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>You, Seungil, David Ding, Kevin Canini, Jan Pfeifer, and Maya Gupta. “Deep lattice networks and partial monotonic functions.” <em>Advances in neural information processing systems</em> 30 (2017). <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>Runje, Davor, and Sharath M. Shankaranarayana. “Constrained monotonic neural networks.” In <em>International Conference on Machine Learning</em>, pp. 29338-29353. PMLR, 2023. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:3" role="doc-endnote">
      <p>Cook, Patrick, Danny Jammooa, Morten Hjorth-Jensen, Daniel D. Lee, and Dean Lee. “Parametric matrix models.” <em>Nature Communications</em> 16, no. 1 (2025): 5929. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="machine learning" /><category term="regression" /><category term="eigenvalue models" /><category term="spectral methods" /><category term="shape constraints" /><category term="monotonicity" /><category term="convex regression" /><category term="concave regression" /><category term="pytorch" /><summary type="html"><![CDATA[Eigenvalues as neurons: represent nonlinear models as the k-th eigenvalue of a learned symmetric matrix pencil. Explore monotonicity/convexity properties and train simple spectral models.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/pow_spec_auction_illustration.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/pow_spec_auction_illustration.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Paying attention to feature distribution alignment</title><link href="https://alexshtf.github.io/2025/08/19/Orthogonality.html" rel="alternate" type="text/html" title="Paying attention to feature distribution alignment" /><published>2025-08-19T00:00:00+00:00</published><updated>2025-08-19T00:00:00+00:00</updated><id>https://alexshtf.github.io/2025/08/19/Orthogonality</id><content type="html" xml:base="https://alexshtf.github.io/2025/08/19/Orthogonality.html"><![CDATA[<h1 id="intro">Intro</h1>

<p>Yes, I’m poking fun at the tendency to put the words “attention” and “alignment” in every ML paper 😎. Now let’s see how this provocative title is related to our adventures in the land of polynomial features.</p>

<p>The Legendre polynomial basis served us well in recent posts about polynomial features. One interesting thing we saw in the series is that its <em>orthogonality</em> is, in some sense, <em>informativeness</em>. This is because orthogonal bases produce uncorrelated features, and hence each basis function, in some sense, carries information that the other basis functions do not. Of course, we all like informative features, so I’d like to devote this post to studying this a bit deeper.</p>

<p>However, the Legendre basis is informative in this sense only if our features are uniformly distributed, and real data usually isn’t. So in this post I’d like to discuss two ways in which we can deal with this practical issue. The associated notebook for reproducing all results is <a href="https://github.com/alexshtf/alexshtf.github.io/blob/master/assets/orthogonality_informativeness.ipynb">here</a>.</p>

<h1 id="orthogonality--informativeness">Orthogonality = informativeness</h1>

<p>So that we are all on the same page, let’s recall why orthogonal bases produce uncorrelated features. Recall that two polynomials \(P_i\) and \(P_j\) defined on \([-1, 1]\) are orthogonal if</p>

\[\langle P_i, P_j \rangle = \int_{-1}^1 P_i(x) P_j(x) dx = 0,\]

<p>just like two orthogonal vectors - their inner product is zero. Here, however, the inner product is an integral rather than a sum. But an integral is also, up to a constant, an expectation: if \(x\) is uniform on \([-1, 1]\), then \(\int_{-1}^1 g(x) dx = 2\,\mathbb{E}[g(x)]\). And empirical averages approximate expectations. So if our data points \(x_1, \dots, x_n\) are approximately uniform in \([-1, 1]\), then</p>

\[0 = \int_{-1}^1 P_i(x) P_j(x) dx \approx \frac{2}{n} \sum_{k=1}^n P_i(x_k) P_j(x_k).\]

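<p>Since \(P_0 = 1\) is part of the basis, orthogonality to \(P_0\) means that every other column has approximately zero mean, so for those columns approximate orthogonality amounts to approximately zero covariance. Hence, any column in the data set the model observes during training is approximately <em>uncorrelated</em> with the other non-constant columns coming from the same orthogonal basis, and thus, in some sense, carries information that the other columns do not have.</p>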
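<p>This approximation is easy to check numerically. A minimal NumPy sketch, assuming uniform samples on \([-1, 1]\) and an arbitrary degree of 5: <code class="language-plaintext highlighter-rouge">numpy.polynomial.legendre.legvander</code> builds the Legendre Vandermonde matrix \(V\), the matrix \(\frac{2}{n} V^T V\) should be close to the diagonal matrix of exact integrals \(\int_{-1}^1 P_k^2(x) dx = \frac{2}{2k+1}\), and the non-constant columns should be nearly uncorrelated.</p>

```python
import numpy as np
from numpy.polynomial import legendre

rng = np.random.default_rng(0)
n, degree = 100_000, 5
x = rng.uniform(-1, 1, size=n)

# column k of V holds P_k evaluated at every sample
V = legendre.legvander(x, degree)

# (2 / n) * V^T V approximates the integrals of P_i * P_j over [-1, 1],
# which form a diagonal matrix with entries 2 / (2k + 1)
gram = (2 / n) * V.T @ V
exact = np.diag(2 / (2 * np.arange(degree + 1) + 1))
max_gram_error = np.abs(gram - exact).max()

# correlations between the non-constant columns (P_0 = 1 has zero variance, so drop it)
corr = np.corrcoef(V[:, 1:].T)
max_offdiag_corr = np.abs(corr - np.eye(degree)).max()
```

<p>Both <code class="language-plaintext highlighter-rouge">max_gram_error</code> and <code class="language-plaintext highlighter-rouge">max_offdiag_corr</code> should be small, and they shrink roughly like \(1/\sqrt{n}\) as the number of samples grows.</p>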

<p>Informativeness, of course, is not the only important trait of a good basis for non-linear features. In fact, even the <em>norms</em> of the basis functions matter, as do other traits. I advise reading the well-written and enlightening <a href="https://arxiv.org/abs/1903.09139">paper</a><sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">1</a></sup> by Muthukumar et al. for more on this. Here, however, we focus mainly on orthogonality as informativeness.</p>

<h1 id="weighted-orthogonality">Weighted orthogonality</h1>

<p>Going back to our linear algebra classes, inner products come in many forms. Given a vector of weights \(\mathbf{w} &gt; 0\), we can define a weighted inner product:</p>

\[\langle \mathbf{x}, \mathbf{y} \rangle_{\mathbf{w}} = \sum_{i=1}^n x_i y_i w_i.\]

<p>The contribution of the pair of components at index \(i\) is weighted by \(w_i\).</p>
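<p>Whether two vectors are orthogonal depends on the weights. A tiny sketch with two made-up vectors that are orthogonal under uniform weights, but not under skewed ones:</p>

```python
import numpy as np

def weighted_inner(x, y, w):
    # <x, y>_w = sum_i x_i * y_i * w_i, for strictly positive weights w
    return np.sum(x * y * w)

x = np.array([1.0, 1.0])
y = np.array([1.0, -1.0])

ip_uniform = weighted_inner(x, y, np.array([1.0, 1.0]))  # 1 - 1 = 0: orthogonal
ip_skewed = weighted_inner(x, y, np.array([2.0, 1.0]))   # 2 - 1 = 1: not orthogonal
```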

<p>Similarly, given a  <em>weight function</em> \(w(x)&gt;0\)  integrable over the domain \(D\), we can define a weighted inner product between two functions on that domain:</p>

\[\langle f, g\rangle_w = \int_{D} f(x)g(x)w(x)dx\]

<p>The contribution at each point \(x\) to the integral is weighted by \(w(x)\). The Legendre basis, for example, is orthogonal on \(D = [-1, 1]\) with respect to the <em>uniform weight</em> \(w(x) = 1\).</p>

<p>Now let’s see what this means in terms of informativeness. Suppose without loss of generality that \(w(x)\) is normalized such that \(\int_{D} w(x) dx = 1\); if it’s not, we can always divide it by its integral. Then it can be thought of as the PDF of some probability distribution over \(x\), and the inner product is just an expectation. Therefore, if our data points \(x_1, \dots, x_n\) come from that distribution, then</p>

\[\langle f, g \rangle_w = \mathbb{E}_x \left[ f(x) g(x) \right] \approx \frac{1}{n} \sum_{i=1}^n f(x_i) g(x_i).\]

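<p>Consequently, if \(f\) and \(g\) are orthogonal w.r.t. our inner product, and at least one of them has zero mean (as happens when the basis contains the constant function and that function is orthogonal to it), the two features generated by \(f(x)\) and \(g(x)\) are approximately uncorrelated, or informative.</p>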
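<p>A classical instance is the family of probabilists’ Hermite polynomials \(He_k\), which are orthogonal w.r.t. the standard normal PDF, with \(\mathbb{E}[He_i(x) He_j(x)] = i!\) when \(i = j\) and zero otherwise. NumPy implements them in <code class="language-plaintext highlighter-rouge">numpy.polynomial.hermite_e</code>. The sketch below merely illustrates the weighted case with a weight whose orthogonal basis is already known; it is not one of the two tricks this post discusses, and the sample size and degree are arbitrary.</p>

```python
import math

import numpy as np
from numpy.polynomial import hermite_e

rng = np.random.default_rng(0)
n, degree = 1_000_000, 3
x = rng.standard_normal(n)  # data drawn from the distribution whose PDF is the weight

# column k of V holds He_k evaluated at every sample
V = hermite_e.hermevander(x, degree)

# empirical inner products (1 / n) * V^T V approximate E[He_i(x) He_j(x)] = i! * [i == j]
gram = V.T @ V / n
expected = np.diag([math.factorial(k) for k in range(degree + 1)])

# correlations between the non-constant columns He_1, ..., He_3
corr = np.corrcoef(V[:, 1:].T)
```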

<p>So ideally we would like to devise bases that are orthogonal w.r.t. the PDF of our data distribution. The differential equations community has been designing orthogonal bases w.r.t. various weights for a long time, and has come up with plenty of methods and an enormous body of literature. But there are two extremely simple tricks we can adopt from them for ML, which I’d like to discuss in this post. Both are also discussed in the context of approximation theory and differential equations in the recent survey paper by Shen and Wang<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">2</a></sup>. One of them is extremely intuitive, easy to implement, and indeed useful in practice. The other requires some careful math before coding and is harder to apply in practice, so I’d like to present it and discuss when I believe it may be useful.</p>

<h1 id="the-mapping-trick">The mapping trick</h1>

<p>Let’s focus on the Legendre basis, which is orthogonal on \([-1, 1]\). Instead of min-max scaling, which we did in previous posts, suppose we use some increasing, differentiable function \(\phi: D \to [-1, 1]\) that maps our feature from its original domain \(D\) onto \([-1, 1]\). In terms of the raw feature, our basis functions are</p>

\[Q_i(x) = P_i(\phi(x)).\]

<p>Are they orthogonal? Well, in some sense, they are. Using the change of variable \(y = \phi(x)\), we know from high-school calculus that for \(i \neq j\):</p>

\[0 = \int_{-1}^1 P_i(y) P_j(y) dy = \int_{D} P_i(\phi(x)) P_j(\phi(x)) \phi'(x) dx = \langle Q_i, Q_j \rangle_{\phi'}\]

<p>The conclusion is simple - mapping with \(\phi\) results in functions that are orthogonal w.r.t. the weight \(\phi'\). In particular, if \(\phi'\) is proportional to the PDF of our data distribution, then using the basis \(Q_0, Q_1, \dots\) results in uncorrelated features!</p>
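<p>The change-of-variable identity can also be verified by numerical integration. A sketch using SciPy’s <code class="language-plaintext highlighter-rouge">quad</code>, where the mapping \(\phi = \tanh\) from \(D = \mathbb{R}\) onto \((-1, 1)\) is a hypothetical choice made only for this check, so that the weight is \(\phi'(x) = 1 - \tanh^2(x)\):</p>

```python
import numpy as np
from numpy.polynomial import legendre
from scipy.integrate import quad

def phi(x):
    # hypothetical increasing mapping from the real line onto (-1, 1)
    return np.tanh(x)

def dphi(x):
    # derivative of tanh, which plays the role of the weight
    return 1.0 - np.tanh(x) ** 2

def mapped_inner(i, j):
    # <Q_i, Q_j>_{phi'} = integral over R of P_i(phi(x)) * P_j(phi(x)) * phi'(x) dx
    P_i = legendre.Legendre.basis(i)
    P_j = legendre.Legendre.basis(j)
    value, _ = quad(lambda x: P_i(phi(x)) * P_j(phi(x)) * dphi(x), -np.inf, np.inf)
    return value

off_diagonal = mapped_inner(1, 3)  # should be close to 0
diagonal = mapped_inner(2, 2)      # should be close to the integral of P_2^2 over [-1, 1], i.e. 2 / 5
```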

<p>So what mapping should we use? If we know or can estimate the CDF \(W\) of our feature, we should use</p>

\[x \to 2W(x) - 1.\]

<p>Indeed, it maps to \([-1, 1]\), and the derivative of this mapping is twice the PDF. Just what we need.</p>

<p>We can, of course, attempt some mathematical acrobatics to extend this to non-differentiable CDFs \(W\), but this is not a paper, just a blog post. In fact, we will use a non-differentiable CDF in our code, without doing the math. This idea aligns with our intuition from the intro: if \(W\) is continuous, then \(W(x)\) is uniformly distributed on \([0, 1]\), so mapping with such a “uniformizing” transformation before computing Legendre polynomials produces a basis that is orthogonal w.r.t. the distribution of the original raw feature.</p>
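<p>This uniformizing property (the probability integral transform) is easy to check by simulation. A quick sketch with an arbitrary Gamma distribution from <code class="language-plaintext highlighter-rouge">scipy.stats</code>: after mapping with \(2W(x) - 1\), the empirical quantiles should match those of the uniform distribution on \([-1, 1]\).</p>

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(0)
dist = scipy.stats.gamma(8, 2)  # arbitrary skewed example: shape 8, shifted by loc=2
samples = dist.rvs(size=100_000, random_state=rng)

# map with 2 * W(x) - 1, where W is the (known) CDF of the samples
mapped = 2 * dist.cdf(samples) - 1

# if the mapped samples are uniform on [-1, 1], their p-quantiles are close to 2p - 1
probs = np.array([0.1, 0.25, 0.5, 0.75, 0.9])
empirical = np.quantile(mapped, probs)
expected = 2 * probs - 1
```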

<h1 id="a-small-simulation">A small simulation</h1>

<p>We shall sample data from some distributions, and use the above mapping to transform it before computing the Legendre Vandermonde matrix. Then, we shall inspect the correlation between its columns. Here is a function that accepts a <code class="language-plaintext highlighter-rouge">scipy.stats</code> distribution object, and computes the correlation matrix:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="k">def</span> <span class="nf">simulate_correlation</span><span class="p">(</span><span class="n">dist</span><span class="p">,</span> <span class="n">degree</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">n_samples</span><span class="o">=</span><span class="mi">10000</span><span class="p">):</span>
    <span class="n">samples</span> <span class="o">=</span> <span class="n">dist</span><span class="p">.</span><span class="n">rvs</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
    <span class="n">mapped</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">dist</span><span class="p">.</span><span class="n">cdf</span><span class="p">(</span><span class="n">samples</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="n">vander</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">(</span><span class="n">mapped</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">corrcoef</span><span class="p">(</span><span class="n">vander</span><span class="p">.</span><span class="n">T</span><span class="p">)</span>
</code></pre></div></div>

<p>Pretty straightforward - sample, transform, compute the Legendre basis functions for each mapped sample, and then the correlation between any two resulting features. So let’s make some plots. Here is a simulation of data having the standard Normal distribution:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">scipy.stats</span>

<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">simulate_correlation</span><span class="p">(</span><span class="n">scipy</span><span class="p">.</span><span class="n">stats</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/orthogonality_norm_std.png" alt="orthogonality_norm_std" /></p>

<p>We see a diagonal of ones, and values close to zero outside the diagonal. Well, except for the first row and column - they correspond to the constant function 1, which has zero variance, so its correlation with the other features is undefined. But that’s OK - in models we typically have a separate bias term, and do not include the constant function in our basis.</p>

<p>What about some non-standard normal?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">simulate_correlation</span><span class="p">(</span><span class="n">scipy</span><span class="p">.</span><span class="n">stats</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">10</span><span class="p">)))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/orthogonality_norm_nonstd.png" alt="orthogonality_norm_std" /></p>

<p>Similar - pairs of distinct features are practically uncorrelated, with correlations close to zero. How about a Gamma distribution?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">simulate_correlation</span><span class="p">(</span><span class="n">scipy</span><span class="p">.</span><span class="n">stats</span><span class="p">.</span><span class="n">gamma</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/orthogonality_norm_gamma.png" alt="orthogonality_norm_std" /></p>

<p>Neat! So if we know our data distribution, we can generate informative non-linear features by composing our CDF-based mapping with the Legendre basis.</p>

<h1 id="the-mapping-trick-in-practice">The mapping trick in practice?</h1>

<p>In practice we don’t know the data distribution of each column. We can estimate it by various means, such as fitting some candidate distributions using SciPy. But we can also use another neat approximation - Scikit-Learn’s <code class="language-plaintext highlighter-rouge">QuantileTransformer</code>, which does approximately what we desire: it estimates the CDF from the training data, and uses that estimate to map raw features to quantiles in \([0, 1]\). We just have to add one small step to map them from \([0, 1]\) to \([-1, 1]\). Note that its approximate CDF is non-differentiable: it interpolates linearly between estimated quantiles, so it is only piecewise linear. We haven’t shown anything for a non-differentiable CDF used as a mapping, so this is where theory serves only as a guide.</p>
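<p>To make the piecewise-linear CDF estimate concrete, here is a simplified NumPy sketch of quantile-based CDF estimation. It is not scikit-learn’s implementation: it merely estimates 1000 evenly spaced quantiles of a training sample, interpolates linearly between them with <code class="language-plaintext highlighter-rouge">np.interp</code>, and compares the result with the true CDF of an arbitrary Gamma distribution. The helper names are hypothetical.</p>

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(0)
dist = scipy.stats.gamma(8, 2)  # arbitrary example distribution
train = dist.rvs(size=10_000, random_state=rng)

# estimate 1000 evenly spaced quantiles of the training sample
levels = np.linspace(0.0, 1.0, 1000)
knots = np.quantile(train, levels)

def approx_cdf(x):
    # piecewise linear CDF estimate: interpolate between the estimated quantiles;
    # np.interp returns 0 below the smallest knot and 1 above the largest one
    return np.interp(x, knots, levels)

def to_legendre_domain(x):
    # the extra 2u - 1 step that maps [0, 1] to [-1, 1]
    return 2 * approx_cdf(x) - 1

# compare with the true CDF on a grid covering most of the probability mass
grid = np.linspace(dist.ppf(0.001), dist.ppf(0.999), 500)
max_cdf_error = np.abs(approx_cdf(grid) - dist.cdf(grid)).max()
```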

<p>Here is a simple pipeline that computes our orthogonal Legendre features, which we will later feed into a linear regression model. It uses our previously developed <code class="language-plaintext highlighter-rouge">LegendreScalarPolynomialFeatures</code> from the last post. This class doesn’t do anything special - it just takes raw feature columns, and computes the Legendre Vandermonde matrix.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">QuantileTransformer</span><span class="p">,</span> <span class="n">FunctionTransformer</span>
<span class="kn">from</span> <span class="nn">sklearn.pipeline</span> <span class="kn">import</span> <span class="n">Pipeline</span>

<span class="k">def</span> <span class="nf">ortho_features_pipeline</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="mi">8</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'quantile-transformer'</span><span class="p">,</span> <span class="n">QuantileTransformer</span><span class="p">()),</span>
        <span class="p">(</span><span class="s">'post-mapper'</span><span class="p">,</span> <span class="n">FunctionTransformer</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="mi">2</span><span class="o">*</span><span class="n">x</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'polyfeats'</span><span class="p">,</span> <span class="n">LegendreScalarPolynomialFeatures</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="n">degree</span><span class="p">)),</span>
    <span class="p">])</span>
</code></pre></div></div>
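<p>If you don’t have the class from the last post at hand, here is a hypothetical stand-in that follows its description: it evaluates \(P_0, \dots, P_d\) on every column and concatenates the per-column Vandermonde matrices. Whether the original class keeps the constant column \(P_0\), or orders the columns the same way, is an assumption on my part.</p>

```python
import numpy as np
from numpy.polynomial import legendre

def legendre_features(X, degree=8):
    # hypothetical stand-in for LegendreScalarPolynomialFeatures (not the original class):
    # evaluate P_0, ..., P_degree on every column of X, assumed to lie in [-1, 1],
    # and place the per-column Vandermonde matrices side by side
    X = np.asarray(X, dtype=float)
    blocks = [legendre.legvander(X[:, j], degree) for j in range(X.shape[1])]
    return np.hstack(blocks)
```

<p>It could replace the last pipeline step as <code class="language-plaintext highlighter-rouge">FunctionTransformer(legendre_features)</code>, using the default degree of 8.</p>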

<p>Let’s try applying it to some simulated data and see if we get uncorrelated features. We shall generate two data columns with a Normal and a Gamma distribution, compute features using our pipeline, and plot their correlation matrix:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># two columns - Normal and Gamma
</span><span class="n">sim_data</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span>
    <span class="n">scipy</span><span class="p">.</span><span class="n">stats</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">).</span><span class="n">rvs</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1</span><span class="p">)),</span>
    <span class="n">scipy</span><span class="p">.</span><span class="n">stats</span><span class="p">.</span><span class="n">gamma</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="n">rvs</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1</span><span class="p">)),</span>
<span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

<span class="c1"># features
</span><span class="n">features</span> <span class="o">=</span> <span class="n">ortho_features_pipeline</span><span class="p">().</span><span class="n">fit_transform</span><span class="p">(</span><span class="n">sim_data</span><span class="p">)</span>

<span class="c1"># plot correlation matrix
</span><span class="n">coef_mat</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">corrcoef</span><span class="p">(</span><span class="n">features</span><span class="p">.</span><span class="n">T</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">coef_mat</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/orthogonality_test_pipeline.png" alt="orthogonality_test_pipeline" /></p>

<p>Nice! Now let’s try training a linear regression model with our new pipeline.</p>
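<p>For intuition about why the correlation matrix above is nearly diagonal, here is a minimal sketch that bypasses the pipeline classes. As a stand-in for <code class="language-plaintext highlighter-rouge">QuantileTransformer</code>, it maps a skewed Gamma-distributed column (parameters chosen arbitrarily) through its rank-based empirical CDF with midpoint ranks. That mapping is my assumption, not what the pipeline does internally. It then evaluates \(P_0, \dots, P_5\) with NumPy’s <code class="language-plaintext highlighter-rouge">legvander</code>. The mapped values form an evenly spaced grid on \((-1, 1)\), so the empirical second-moment matrix should be close to \(\operatorname{diag}(1/(2n+1))\), whatever the original distribution was.</p>

```python
import numpy as np
from numpy.polynomial.legendre import legvander
from scipy.stats import rankdata

rng = np.random.default_rng(0)
x = rng.gamma(shape=8, scale=2, size=1000)  # a skewed column (arbitrary parameters)

# Rank-based empirical CDF mapped to (-1, 1); midpoint ranks keep t off the endpoints
t = 2 * (rankdata(x) - 0.5) / len(x) - 1

V = legvander(t, 5)                                # columns P_0(t), ..., P_5(t)
gram = V.T @ V / len(x)                            # empirical E[P_m(t) P_n(t)]
expected = np.diag(1.0 / (2 * np.arange(6) + 1))   # uniform t on [-1, 1]: E[P_n^2] = 1/(2n+1)
print(np.abs(gram - expected).max())
```

<p>Only the ranks enter the computation, so swapping the Gamma sample for any other continuous distribution (without ties) should give the same matrix.</p>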

<h1 id="testing-on-real-data">Testing on real data</h1>

<p>Let’s load our beloved California housing dataset and see what we have achieved. As always, we apply a log transformation to the skewed columns:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>

<span class="n">train_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s">"sample_data/california_housing_train.csv"</span><span class="p">)</span>
<span class="n">test_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s">"sample_data/california_housing_test.csv"</span><span class="p">)</span>

<span class="n">X_train</span> <span class="o">=</span> <span class="n">train_df</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="s">"median_house_value"</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">y_train</span> <span class="o">=</span> <span class="n">train_df</span><span class="p">[</span><span class="s">"median_house_value"</span><span class="p">]</span>

<span class="n">X_test</span> <span class="o">=</span> <span class="n">test_df</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="s">"median_house_value"</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">y_test</span> <span class="o">=</span> <span class="n">test_df</span><span class="p">[</span><span class="s">"median_house_value"</span><span class="p">]</span>

<span class="n">skewed_columns</span> <span class="o">=</span> <span class="p">[</span><span class="s">'total_rooms'</span><span class="p">,</span> <span class="s">'total_bedrooms'</span><span class="p">,</span> <span class="s">'population'</span><span class="p">,</span> <span class="s">'households'</span><span class="p">]</span>
<span class="n">X_train</span><span class="p">.</span><span class="n">loc</span><span class="p">[:,</span> <span class="n">skewed_columns</span><span class="p">]</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">].</span><span class="nb">apply</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">)</span>
<span class="n">X_test</span><span class="p">.</span><span class="n">loc</span><span class="p">[:,</span> <span class="n">skewed_columns</span><span class="p">]</span> <span class="o">=</span> <span class="n">X_test</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">].</span><span class="nb">apply</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">)</span>
</code></pre></div></div>

<p>Now let’s fit a linear regression model and see that it works:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">LinearRegression</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">root_mean_squared_error</span>

<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
    <span class="p">(</span><span class="s">'ortho-features'</span><span class="p">,</span> <span class="n">ortho_features_pipeline</span><span class="p">()),</span>
    <span class="p">(</span><span class="s">'lin-reg'</span><span class="p">,</span> <span class="n">LinearRegression</span><span class="p">()),</span>
<span class="p">])</span>
<span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>62672.703496184964
</code></pre></div></div>

<p>Let’s compare it to the min-max scaling strategy we tried in previous posts:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">MinMaxScaler</span>

<span class="k">def</span> <span class="nf">minmax_legendre_features</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="mi">8</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'scaler'</span><span class="p">,</span> <span class="n">MinMaxScaler</span><span class="p">(</span><span class="n">feature_range</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">clip</span><span class="o">=</span><span class="bp">True</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'polyfeats'</span><span class="p">,</span> <span class="n">LegendreScalarPolynomialFeatures</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="n">degree</span><span class="p">)),</span>
    <span class="p">])</span>

<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
    <span class="p">(</span><span class="s">'minmax-legendre'</span><span class="p">,</span> <span class="n">minmax_legendre_features</span><span class="p">()),</span>
    <span class="p">(</span><span class="s">'lin-reg'</span><span class="p">,</span> <span class="n">LinearRegression</span><span class="p">()),</span>
<span class="p">])</span>
<span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>63426.15965332127
</code></pre></div></div>

<p>So, at least for the default Legendre polynomial degree, the approximately orthogonal features appear to work quite well. Let’s compare several degrees:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">deg</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">22</span><span class="p">,</span> <span class="mi">2</span><span class="p">):</span>
    <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'minmax-legendre'</span><span class="p">,</span> <span class="n">minmax_legendre_features</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="n">deg</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'lin-reg'</span><span class="p">,</span> <span class="n">LinearRegression</span><span class="p">()),</span>
    <span class="p">])</span>
    <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
    <span class="n">minmax_rmse</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>

    <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'ortho-features'</span><span class="p">,</span> <span class="n">ortho_features_pipeline</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="n">deg</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'lin-reg'</span><span class="p">,</span> <span class="n">LinearRegression</span><span class="p">()),</span>
    <span class="p">])</span>
    <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
    <span class="n">ortho_rmse</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>

    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Degree = </span><span class="si">{</span><span class="n">deg</span><span class="si">}</span><span class="s">, minmax_rmse = </span><span class="si">{</span><span class="n">minmax_rmse</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">, ortho_rmse = </span><span class="si">{</span><span class="n">ortho_rmse</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>
<p>The output is:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Degree = 1, minmax_rmse = 67775.24, ortho_rmse = 74588.32
Degree = 3, minmax_rmse = 65137.16, ortho_rmse = 67551.46
Degree = 5, minmax_rmse = 64054.74, ortho_rmse = 65010.42
Degree = 7, minmax_rmse = 63523.41, ortho_rmse = 63297.67
Degree = 9, minmax_rmse = 63440.02, ortho_rmse = 61606.44
Degree = 11, minmax_rmse = 63305.14, ortho_rmse = 61438.60
Degree = 13, minmax_rmse = 65575.86, ortho_rmse = 61237.12
Degree = 15, minmax_rmse = 175047.47, ortho_rmse = 60611.78
Degree = 17, minmax_rmse = 175270.39, ortho_rmse = 60680.52
Degree = 19, minmax_rmse = 781416.93, ortho_rmse = 60111.46
</code></pre></div></div>

<p>At least on this dataset, the truly orthogonal features appear to be better from degree 7 onward, while min-max scaling wins at the lower degrees. Note how the error of the naively scaled basis starts climbing at degree 13 and explodes from degree 15 onward - we’re losing informativeness. Of course, we know that if we crank up the degree to 10,000, we will observe double descent, and all the other nice stuff we saw in previous posts. But that’s not the point of this post.</p>
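<p>Here is one rough way to see what goes wrong. It is sketched on an assumed, heavily skewed (log-normal) column rather than on the California housing data. After min-max scaling, most values pile up near \(-1\), so the Legendre columns become nearly collinear. The rank-based CDF mapping keeps them well spread instead. I use the condition number of the feature matrix as a proxy for this sketch; the experiments above do not measure it.</p>

```python
import numpy as np
from numpy.polynomial.legendre import legvander
from scipy.stats import rankdata

rng = np.random.default_rng(0)
x = rng.lognormal(mean=0.0, sigma=1.0, size=2000)  # assumed heavily skewed column
deg = 15

# What MinMaxScaler(feature_range=(-1, 1)) does to a single column
t_minmax = 2 * (x - x.min()) / (x.max() - x.min()) - 1
# Rank-based stand-in for the quantile (CDF) mapping
t_rank = 2 * (rankdata(x) - 0.5) / len(x) - 1

cond_minmax = np.linalg.cond(legvander(t_minmax, deg))
cond_rank = np.linalg.cond(legvander(t_rank, deg))
print(f'min-max: {cond_minmax:.3g}, rank-based: {cond_rank:.3g}')
```

<p>For the rank-based features, the condition number should stay near \(\sqrt{2 \cdot 15 + 1} \approx 5.6\), the ratio of the largest to smallest Legendre norm under a uniform weight.</p>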

<p>What about Ridge regression? Maybe it behaves differently? It should - regularization should “tame” the behavior of the min-max scaled features. But are the orthogonal features still better?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">RidgeCV</span>

<span class="k">for</span> <span class="n">deg</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">22</span><span class="p">,</span> <span class="mi">2</span><span class="p">):</span>
    <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'minmax-legendre'</span><span class="p">,</span> <span class="n">minmax_legendre_features</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="n">deg</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'lin-reg'</span><span class="p">,</span> <span class="n">RidgeCV</span><span class="p">()),</span>
    <span class="p">])</span>
    <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
    <span class="n">minmax_rmse</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>

    <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'ortho-features'</span><span class="p">,</span> <span class="n">ortho_features_pipeline</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="n">deg</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'lin-reg'</span><span class="p">,</span> <span class="n">RidgeCV</span><span class="p">()),</span>
    <span class="p">])</span>
    <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
    <span class="n">ortho_rmse</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>

    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Degree = </span><span class="si">{</span><span class="n">deg</span><span class="si">}</span><span class="s">, minmax_rmse = </span><span class="si">{</span><span class="n">minmax_rmse</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">, ortho_rmse = </span><span class="si">{</span><span class="n">ortho_rmse</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>
<p>The output is:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Degree = 1, minmax_rmse = 67771.06, ortho_rmse = 74482.30
Degree = 3, minmax_rmse = 65121.63, ortho_rmse = 67485.67
Degree = 5, minmax_rmse = 64077.37, ortho_rmse = 64971.80
Degree = 7, minmax_rmse = 63541.46, ortho_rmse = 63380.61
Degree = 9, minmax_rmse = 63394.64, ortho_rmse = 61650.67
Degree = 11, minmax_rmse = 62889.19, ortho_rmse = 61386.24
Degree = 13, minmax_rmse = 62305.72, ortho_rmse = 61192.17
Degree = 15, minmax_rmse = 62276.36, ortho_rmse = 60649.21
Degree = 17, minmax_rmse = 62045.29, ortho_rmse = 60677.81
Degree = 19, minmax_rmse = 61883.94, ortho_rmse = 60165.78
Degree = 21, minmax_rmse = 61802.18, ortho_rmse = 59526.32
</code></pre></div></div>

<p>It appears they are. Except for the low-degree polynomials (degree 5 and below), the orthogonal features we obtained by composing Legendre polynomials with the (empirical) CDF appear to outperform naive min-max scaling.</p>

<p>Obviously, in practice the degree is a tunable hyperparameter. Each candidate degree should be evaluated on a validation set, and only the best configuration should then be evaluated on the test set. But when the same phenomenon appears across many degrees, I believe the comparison is convincing enough.</p>
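<p>Here is a minimal sketch of that protocol, run on made-up synthetic data rather than California housing. The hypothetical <code class="language-plaintext highlighter-rouge">legendre_features</code> helper mirrors the quantile-then-Legendre idea. It uses scikit-learn’s <code class="language-plaintext highlighter-rouge">QuantileTransformer</code> and NumPy’s <code class="language-plaintext highlighter-rouge">legvander</code> instead of this post’s <code class="language-plaintext highlighter-rouge">LegendreScalarPolynomialFeatures</code> class. It also drops \(P_0\), because <code class="language-plaintext highlighter-rouge">RidgeCV</code> fits its own intercept. The degree is picked by validation RMSE only; the test set would be used once, after the choice is made.</p>

```python
import numpy as np
from numpy.polynomial.legendre import legvander
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer


def legendre_features(qt, X, degree):
    """Hypothetical helper: quantile-map each column to [-1, 1], then stack P_1..P_degree."""
    t = 2 * qt.transform(X) - 1
    # drop P_0 (the constant column); RidgeCV fits its own intercept
    return np.hstack([legvander(t[:, j], degree)[:, 1:] for j in range(t.shape[1])])


def select_degree(X_train, y_train, X_valid, y_valid, degrees):
    """Pick the degree with the lowest validation RMSE; the test set is not used here."""
    qt = QuantileTransformer(n_quantiles=min(1000, len(X_train))).fit(X_train)
    scores = {}
    for deg in degrees:
        model = RidgeCV().fit(legendre_features(qt, X_train, deg), y_train)
        pred = model.predict(legendre_features(qt, X_valid, deg))
        scores[deg] = np.sqrt(np.mean((y_valid - pred) ** 2))  # validation RMSE
    best = min(scores, key=scores.get)
    return best, scores


# Synthetic usage example (made-up data, not California housing)
rng = np.random.default_rng(0)
X = rng.gamma(8, 2, size=(2000, 2))
y = np.sin(X[:, 0] / 4) + 0.1 * X[:, 1] + 0.1 * rng.normal(size=2000)
X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.25, random_state=0)
best_deg, scores = select_degree(X_tr, y_tr, X_va, y_va, degrees=range(1, 12, 2))
print(best_deg, {d: round(float(s), 4) for d, s in scores.items()})
```

<p>After this step, one would refit at <code class="language-plaintext highlighter-rouge">best_deg</code> and report the test RMSE a single time.</p>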

<p>This is not a paper, and this is not a thorough benchmark on a variety of datasets. That is not the point - the point is that even though data speak, theory guides. And its guidance is often useful, if you listen carefully.</p>

<h1 id="the-multiplication-by-one-trick">The multiplication by one trick</h1>

<p>Many deep results in mathematics are derived by multiplying by one, but the “one” has to be chosen wisely. So let’s choose ours wisely. We have:</p>

\[f(x) g(x) w(x) = \frac{f(x)}{u(x)} \frac{g(x)}{u(x)} w(x) u^2(x)\]

<p>This looks a bit hairy, but all we did was multiply and divide by \(u^2(x)\), assuming \(u(x) \neq 0\). Consequently, if \(\langle f, g \rangle_w = 0\), then</p>

\[\left\langle \frac{f}{u}, \frac{g}{u} \right \rangle_{w u^2} = 0.\]

<p>So, taking a basis \(Q_0, Q_1, \dots\) orthogonal w.r.t. \(w\), and dividing all functions by a given non-vanishing function \(u(x)\), we obtain a basis orthogonal w.r.t. the new weight \(wu^2\). This is reminiscent of the well-known <a href="https://en.wikipedia.org/wiki/Importance_sampling">importance sampling</a> Monte-Carlo algorithms, where we can choose a more convenient distribution to sample from by multiplying and dividing by a well-chosen density.</p>
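<p>Here is a quick numerical check of the identity above, as a sketch. It takes the Legendre polynomials, which are orthogonal w.r.t. \(w(x) = 1\) on \([-1, 1]\), and divides them by the arbitrarily chosen \(u(x) = 2 + x\), which does not vanish on \([-1, 1]\). It then computes inner products with NumPy’s Gauss-Legendre quadrature. W.r.t. the new weight \(w u^2\), the Gram matrix should be \(\operatorname{diag}(2/(2k+1))\). W.r.t. the old weight \(w\), the divided functions are no longer orthogonal.</p>

```python
import numpy as np
from numpy.polynomial.legendre import Legendre, leggauss

# 40-point Gauss-Legendre rule on [-1, 1]: exact for polynomials up to degree 79
nodes, weights = leggauss(40)


def inner(f, g, weight):
    """Approximate the weighted inner product: integral of f * g * weight over [-1, 1]."""
    return np.sum(weights * f(nodes) * g(nodes) * weight(nodes))


u = lambda x: 2 + x                 # assumed divisor, non-zero on [-1, 1]
w = lambda x: np.ones_like(x)       # Legendre weight w(x) = 1
new_w = lambda x: w(x) * u(x) ** 2  # the new weight w * u^2

P = [Legendre.basis(k) for k in range(4)]
Q = [lambda x, p=p: p(x) / u(x) for p in P]  # Q_k = P_k / u

gram = np.array([[inner(Q[m], Q[n], new_w) for n in range(4)] for m in range(4)])
old = inner(Q[0], Q[1], w)          # same pair of functions, old weight w

print(np.round(gram, 6))  # diag(2, 2/3, 2/5, 2/7)
print(old)                # about -0.235
```

<p>The non-zero value is \(\int_{-1}^{1} \frac{x}{(2+x)^2}\,dx = \ln 3 + \frac{2}{3} - 2\). Dividing by \(u\) moves orthogonality to a different weight; it does not preserve orthogonality w.r.t. the old one.</p>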

<p>We can choose the weight we desire by carefully choosing \(u\), right?  Easy peasy - just divide all functions by \(u\), and we’re done! But it’s not that simple, because the new basis</p>

\[\frac{Q_0}{u}, \frac{Q_1}{u}, \dots\]

<p>might have radically different representation power, especially if we take only a finite number of basis functions, as we do in machine learning.</p>

<p>For example, if our original basis was bounded, but the ratio \(\frac{1}{u(x)}\) isn’t, then the new basis suddenly consists of unbounded functions that may grow to infinity.  Alternatively, if this ratio decays towards zero, then the new basis functions also decay towards zero.</p>

<p>Let’s look at a concrete case to understand the issue. Suppose our feature \(x\) is the total time the user spent on our website in the last month. The effect of this feature on user behavior probably “flattens” at some point - users who spent 5 minutes may be different from those who spent 10 minutes, but those who spent 10 hours may not be that different from those who spent 20. We certainly would <em>not</em> want a function that grows to infinity as \(x\) grows!</p>

<p>The other side of our observation is that this idea of dividing by \(u\) gives us some degree of control. You can mix it with the mapping approach to design a family of features that are orthogonal w.r.t. the weight of your choice and also have the properties you want, such as growth or decay. These desired properties are the <em>inductive bias</em> you bake into your model to help it generalize better to unseen data, or conform to some regulatory or safety constraints.</p>

<p>This mixing and matching sounds easy, but it may not be. Let’s look at an example, just to give you a feel for it. Suppose our total time on the website has a distribution with CDF \(U\), and we want features that decay towards zero, because the effect of this total time eventually flattens out. If we take some mapping \(T: [0, \infty) \to [-1, 1]\), we can construct a basis from Legendre polynomials:</p>

\[P_0(T(x)), P_1(T(x)), P_2(T(x)), ...\]

<p>This basis will be orthogonal w.r.t. the weight \(T'(x)\), according to what we saw about the mapping trick (assuming \(T\) is increasing, so that \(T'(x) &gt; 0\)). Now, to make these functions decay towards zero, we may want to divide them by some unbounded function, such as \(u(x)=1 + x\), and get a new basis:</p>

\[\underbrace{\frac{P_0(T(x))}{1+x}}_{Q_0(x)}, \underbrace{\frac{P_1(T(x))}{1+x}}_{Q_1(x)}, ...\]

<p>By the multiplication-by-one trick, \(Q_0, Q_1, \dots\) are orthogonal w.r.t. the weight function \(w(x) = T'(x) (1 + x)^2\).</p>

<p>So now, we want to design the mapping \(T\) such that the weight aligns with the data distribution, to get informative features from our basis. This means we want the weight to be proportional to the PDF:</p>

\[T'(x) (1 + x)^2 = a U'(x), \qquad a &gt; 0.\]

<p>Equivalently,</p>

\[T(x) = a\int_0^x \frac{U'(s)}{(1+s)^2}\,ds + b.\]

<p>Here, we choose \(a\) and \(b\) such that \(T\) maps \([0, \infty)\) onto the interval \([-1, 1]\), namely \(b = -1\) and \(a = 2 \big/ \int_0^\infty \frac{U'(s)}{(1+s)^2}\,ds\). This ensures that our features are both orthogonal w.r.t. the right weight function, the PDF of the data distribution, and also decay towards zero, because we divided bounded Legendre polynomials by \(u(x) = 1 + x\). Of course, there are many other decay-inducing choices of \(u\), such as \(u(x)=\exp(x)\). This is where your “feature engineering” voodoo kicks in.</p>
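<p>To make the recipe concrete, here is a sketch under an assumption the post does not make: the total time \(x\) follows an Exponential(1) distribution, so \(U'(x) = e^{-x}\). With \(u(x) = 1 + x\), the sketch computes \(T\) by numerical quadrature with SciPy’s <code class="language-plaintext highlighter-rouge">quad</code> and picks \(a\) and \(b\) as above. It then checks that the resulting \(Q_k\) are orthogonal w.r.t. the PDF. From the derivation, \(\int_0^\infty Q_m Q_n U'\,dx = \frac{1}{a}\int_{-1}^{1} P_m P_n\,dt\), so the diagonal entries should equal \(\frac{2}{a(2n+1)}\).</p>

```python
import numpy as np
from scipy.integrate import quad
from scipy.special import eval_legendre

# Assumption for illustration only: x ~ Exponential(1), so U'(x) = exp(-x) on [0, inf)
pdf = lambda s: np.exp(-s)
integrand = lambda s: pdf(s) / (1 + s) ** 2  # U'(s) / (1 + s)^2

# b = -1 and a = 2 / I give T(0) = -1 and T(x) -> 1 as x -> inf
I, _ = quad(integrand, 0, np.inf)
a, b = 2.0 / I, -1.0


def T(x):
    """T(x) = a * (integral of U'(s) / (1 + s)^2 from 0 to x) + b, by quadrature."""
    val, _ = quad(integrand, 0, x)
    return a * val + b


def Q(k, x):
    """Decaying basis function Q_k(x) = P_k(T(x)) / (1 + x)."""
    return eval_legendre(k, T(x)) / (1 + x)


def inner(m, n):
    """Inner product of Q_m and Q_n, weighted by the data PDF U'(x)."""
    val, _ = quad(lambda x: Q(m, x) * Q(n, x) * pdf(x), 0, np.inf, limit=200)
    return val


# Off-diagonal entries should be ~0; inner(1, 1) should match 2 / (3a)
print(inner(0, 1), inner(1, 2), inner(1, 1), 2 / (3 * a))
```

<p>The nested quadrature is slow, but it is transparent. Note that even this simple case needed an integral that most feature pipelines would not compute for you.</p>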

<p>But now we see why this approach is less useful in practice - complexity. It’s hard to choose the “right” decay function. You’ll also have to estimate some distribution \(U\) from the data, and then compute an integral. Maybe this \(U\) is just a non-differentiable empirical CDF - and you’ll need to do some acrobatics to deal with it. And of course you’ll have to know the lower and upper bounds of the integral, so that you can choose \(a\) and \(b\).</p>
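<p>As an illustration of one possible piece of that acrobatics, here is my own sketch. It is not worked out in this post, and it does nothing about choosing the decay function. If \(U\) is the empirical CDF, then \(U'\) is a sum of point masses. The integral defining \(T\) then collapses into a weighted cumulative sum with weights \((1+x_i)^{-2}\), and normalizing that sum fixes \(a\) and \(b\). The hypothetical <code class="language-plaintext highlighter-rouge">fit_decaying_legendre</code> helper does this for \(u(x) = 1 + x\) and non-negative \(x\). It places training points at the midpoints of the jumps, the same trick as midpoint ranks. It is checked only on made-up Gamma-distributed data.</p>

```python
import numpy as np
from numpy.polynomial.legendre import legvander


def fit_decaying_legendre(x_train, degree):
    """Hypothetical helper (not from this post): returns a map from x to the features
    P_k(T(x)) / (1 + x), k = 0..degree, with U the empirical CDF of x_train.
    Assumes non-negative x."""
    xs = np.sort(np.asarray(x_train, dtype=float))
    raw = (1 + xs) ** -2.0  # U'(s) / (1 + s)^2 for each point mass
    cum = np.concatenate([[0.0], np.cumsum(raw / raw.sum())])

    def T(x):
        # The integral is a step function. Averaging its left and right limits puts
        # training points at the midpoints of the jumps; the result lies in [-1, 1].
        left = cum[np.searchsorted(xs, x, side='left')]
        right = cum[np.searchsorted(xs, x, side='right')]
        return left + right - 1

    def features(x):
        x = np.asarray(x, dtype=float)
        return legvander(T(x), degree) / (1 + x)[:, None]

    return features, raw.sum()


rng = np.random.default_rng(0)
x = rng.gamma(shape=8, scale=2, size=5000)  # made-up stand-in for "total time"
features, s = fit_decaying_legendre(x, degree=3)

Q = features(x)
gram = Q.T @ Q / s        # (1/N) * sum_i Q_m(x_i) Q_n(x_i), rescaled by N / s
print(np.round(gram, 3))  # should be close to diag(1, 1/3, 1/5, 1/7)
```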

<p>I haven’t done all this math, so I don’t know the extent to which all of these are easily solvable. But it seems like a lot of trouble! Definitely not something we’re used to doing at work, and much harder than just stacking a <code class="language-plaintext highlighter-rouge">QuantileTransformer</code> before your favorite orthogonal polynomial basis. Thus, unless you’re absolutely sure you have to build these inductive biases into your model, e.g. for safety or regulatory reasons, I wouldn’t go in this direction. But I hope you appreciate the fact that you can actually do it with enough effort.</p>

<p>Of course, because of this complexity, an experiment on our dataset is beyond the scope of this post, and I’ll leave this part as is - just theoretical.</p>

<h1 id="summary">Summary</h1>

<p>Now it appears clear why the provocative title fits this post - we indeed paid close attention to the alignment between our non-linear features and the data distribution. This alignment is manifested in the form of the weight of the inner-product space our basis functions live in. I will repeat myself here as well - data speaks, but theory guides, if you care to listen.</p>

<p>Note that the differential equations community puts a different emphasis on these orthogonal bases. The weight function plays a different role, and the “inductive bias” we talked about in the multiplication trick is less relevant - they care about approximation power, and less about inductive biases. The two are related, but not exactly the same. So the paper by Shen and Wang may be interesting reading, but it focuses on different things.</p>

<h1 id="references">References</h1>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:2" role="doc-endnote">
      <p>Muthukumar, V., Vodrahalli, K., Subramanian, V. and Sahai, A., 2019. Harmless interpolation of noisy data in regression. <em>arXiv preprint arXiv:1903.09139</em>. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:1" role="doc-endnote">
      <p>Shen, J. and Wang, L.L., 2009. Some recent advances on spectral methods for unbounded domains. <em>Communications in computational physics</em>, <em>5</em>(2-4), pp.195-241. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="machine learning" /><category term="feature engineering" /><category term="polynomials" /><category term="polynomial regression" /><category term="Legendre polynomials" /><category term="orthogonal polynomials" /><category term="feature scaling" /><category term="quantile transformer" /><category term="correlation analysis" /><category term="California housing dataset" /><summary type="html"><![CDATA[Orthogonal polynomial features are only uncorrelated when the feature distribution matches the basis weight. Use CDF/quantile transforms to align distributions and get more informative Legendre features.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/orthogonality_test_pipeline.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/orthogonality_test_pipeline.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Off with the polynomial’s tail!</title><link href="https://alexshtf.github.io/2025/04/17/Polynomial-Pruning.html" rel="alternate" type="text/html" title="Off with the polynomial’s tail!" /><published>2025-04-17T00:00:00+00:00</published><updated>2025-04-17T00:00:00+00:00</updated><id>https://alexshtf.github.io/2025/04/17/Polynomial-Pruning</id><content type="html" xml:base="https://alexshtf.github.io/2025/04/17/Polynomial-Pruning.html"><![CDATA[<h1 id="intro">Intro</h1>

<p>Last time we did a small curve-fitting exercise - we fit high-degree polynomials represented in the Legendre basis to a curve, and observed the celebrated “double descent” phenomenon: after crossing the memorization threshold, the generalization error improves as the degree increases. Then, to try to explain this double descent, we took a closer look at what happens when we fit a high-degree Legendre polynomial. We observed that the Legendre basis polynomials oscillate, and therefore behave like a kind of “frequency domain”. We conjectured that the coefficients of lower-degree basis functions are responsible for the overall shape of the fit curve, whereas the coefficients of higher-degree basis functions model rapid fluctuations that fit the noisy data’s deviations from that overall shape. Then, we looked at “pruned” polynomials obtained by keeping only a few initial coefficients and discarding the rest, and indeed saw that the pruned function captures the overall shape.</p>

<p>In this post we shall study this phenomenon not for fitting a curve, but for fitting a regression model to our favorite dataset in this blog - the California Housing dataset. When writing this post I learned something surprising and new, and I hope to surprise you as well. This may not be a new state-of-the-art method, but it is a surprising insight, heavily inspired by a short online <a href="https://x.com/bremen79/status/1907132804313272371">discussion</a> with Prof. Francesco Orabona about what it means for a model to be “simple”. In fact, this discussion is what led me to write this post.</p>

<p>As always, the code can be found in a <a href="https://github.com/alexshtf/alexshtf.github.io/blob/master/assets/calirofnia_housing_legendre_pruning.ipynb">notebook</a> you can deploy to Colab and play with yourself. There will be no formulas or math in this post - mostly code and plots. So let’s get started!</p>

<h1 id="double-descent-with-california-housing">Double descent with California housing</h1>

<p>Let’s start by preparing the data. We do some standard stuff, nothing fancy. Load the data:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>

<span class="n">train_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s">"sample_data/california_housing_train.csv"</span><span class="p">)</span>
<span class="n">test_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s">"sample_data/california_housing_test.csv"</span><span class="p">)</span>
<span class="n">train_df</span><span class="p">.</span><span class="n">head</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code> longitude  latitude  housing_median_age  total_rooms  total_bedrooms  population  households  median_income  median_house_value
   -118.07     33.81                22.0       2711.0           352.0      1305.0       368.0         8.5407            398800.0
   -117.63     33.50                12.0       3619.0           536.0      1506.0       492.0         7.2013            353600.0
   -117.09     32.57                17.0        444.0            83.0       357.0        87.0         5.1478            138900.0
   -117.16     32.81                34.0       2275.0           375.0      1021.0       379.0         3.6371            176300.0
   -118.07     34.17                36.0       2415.0           394.0      1215.0       413.0         5.5418            326100.0
</code></pre></div></div>

<p>Split the loaded training set into training and validation sets:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>

<span class="n">train_df</span><span class="p">,</span> <span class="n">valid_df</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">train_df</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>
</code></pre></div></div>

<p>Separate the features from the prediction target column:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X_train</span> <span class="o">=</span> <span class="n">train_df</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="s">"median_house_value"</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">y_train</span> <span class="o">=</span> <span class="n">train_df</span><span class="p">[</span><span class="s">"median_house_value"</span><span class="p">]</span>

<span class="n">X_valid</span> <span class="o">=</span> <span class="n">valid_df</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="s">"median_house_value"</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">y_valid</span> <span class="o">=</span> <span class="n">valid_df</span><span class="p">[</span><span class="s">"median_house_value"</span><span class="p">]</span>

<span class="n">X_test</span> <span class="o">=</span> <span class="n">test_df</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="s">"median_house_value"</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">y_test</span> <span class="o">=</span> <span class="n">test_df</span><span class="p">[</span><span class="s">"median_house_value"</span><span class="p">]</span>
</code></pre></div></div>

<p>Finally, some numerical columns have extremely skewed distributions, which become much better behaved after a log transformation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="n">skewed_columns</span> <span class="o">=</span> <span class="p">[</span><span class="s">'total_rooms'</span><span class="p">,</span> <span class="s">'total_bedrooms'</span><span class="p">,</span> <span class="s">'population'</span><span class="p">,</span> <span class="s">'households'</span><span class="p">]</span>
<span class="n">X_train</span><span class="p">.</span><span class="n">loc</span><span class="p">[:,</span> <span class="n">skewed_columns</span><span class="p">]</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">].</span><span class="nb">apply</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">)</span>
<span class="n">X_valid</span><span class="p">.</span><span class="n">loc</span><span class="p">[:,</span> <span class="n">skewed_columns</span><span class="p">]</span> <span class="o">=</span> <span class="n">X_valid</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">].</span><span class="nb">apply</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">)</span>
<span class="n">X_test</span><span class="p">.</span><span class="n">loc</span><span class="p">[:,</span> <span class="n">skewed_columns</span><span class="p">]</span> <span class="o">=</span> <span class="n">X_test</span><span class="p">[</span><span class="n">skewed_columns</span><span class="p">].</span><span class="nb">apply</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">)</span>
</code></pre></div></div>

<p>I won’t include plots to convince you that these are the “skewed” columns, since that is not the objective of this post. So here you will have to trust me :)</p>
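<p>If you would rather check than trust, the sample skewness is a quick numeric substitute for plots: pandas’ <code class="language-plaintext highlighter-rouge">Series.skew</code> computes it, and values near zero indicate a roughly symmetric distribution. The sketch below uses synthetic log-normal data as a stand-in for a count-like column (an assumption, not the real dataset); on the real data, one could compare <code class="language-plaintext highlighter-rouge">X_train[skewed_columns].skew()</code> before and after the log transform.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
import pandas as pd

# Synthetic stand-in for a count-like column such as total_rooms:
# log-normal by construction, NOT the real California housing data.
rng = np.random.default_rng(0)
counts = pd.Series(rng.lognormal(mean=7.5, sigma=0.8, size=10_000))

raw_skew = counts.skew()          # strongly right-skewed
log_skew = np.log(counts).skew()  # close to zero after the log transform
print(f"skewness before log: {raw_skew:.2f}, after log: {log_skew:.2f}")
</code></pre></div></div>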

<p>Now let’s get to the meat. We will rely heavily on Scikit-Learn here, and in particular on <a href="https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html">Pipeline</a> objects, which streamline data preparation before model fitting. A pipeline follows the “fit-predict” paradigm: we fit the entire pipeline on the training set, and then use it for prediction. The data preparation components of a pipeline are called <code class="language-plaintext highlighter-rouge">Transformer</code> objects in Scikit-Learn, since they transform data. Do not confuse them with Transformer models used for language tasks.</p>

<p>So let’s write a simple transformer that converts each column of a dataset with numerical features to a corresponding Legendre Vandermonde matrix. Recall that in our last post we relied on the <code class="language-plaintext highlighter-rouge">np.polynomial.legendre.legvander</code> NumPy function, and it turns out that it naturally handles datasets with multiple columns. Let’s see an example:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span>
    <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">],</span>
    <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">],</span>
    <span class="p">[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
<span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[[[ 1.     0.    -0.5   -0.     0.375]
  [ 1.     0.2   -0.44  -0.28   0.232]]

 [[ 1.     0.4   -0.26  -0.44  -0.113]
  [ 1.     0.6    0.04  -0.36  -0.408]]

 [[ 1.     0.8    0.46   0.08  -0.233]
  [ 1.     1.     1.     1.     1.   ]]]
</code></pre></div></div>

<p>The output is a 3D array of shape (rows, columns, degree + 1), here 3 × 2 × 5: each scalar has been expanded into a vector of five Legendre basis values, corresponding to the five basis functions of degree at most 4.</p>

<p>But to fit a model we need a 2D array rather than a 3D one: the rows are the samples, and the columns are the features. So all we need to do is merge the last two dimensions using a simple reshape operation, which horizontally concatenates the Legendre features of each column:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="mi">4</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[[ 1.     0.    -0.5   -0.     0.375  1.     0.2   -0.44  -0.28   0.232]
 [ 1.     0.4   -0.26  -0.44  -0.113  1.     0.6    0.04  -0.36  -0.408]
 [ 1.     0.8    0.46   0.08  -0.233  1.     1.     1.     1.     1.   ]]
</code></pre></div></div>

<p>Now we see two “blocks” of Legendre Vandermonde matrices, concatenated horizontally, one block for every column. Each block has 5 columns, corresponding to the five basis functions of degree at most 4.</p>
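<p>To make the shapes concrete, here is a small check on the same toy matrix that the reshape lays out exactly one block per original column, and that block \(j\) coincides with the Legendre Vandermonde matrix of column \(j\) alone:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre as L

X = np.array([[0, 0.2], [0.4, 0.6], [0.8, 1]])
degree = 4

vander3d = L.legvander(X, degree)        # shape (n_rows, n_cols, degree + 1)
flat = vander3d.reshape(X.shape[0], -1)  # shape (n_rows, n_cols * (degree + 1))
print(vander3d.shape, flat.shape)        # (3, 2, 5) (3, 10)

# Block j of the flattened matrix is the Vandermonde matrix of column j alone
width = degree + 1
for j in range(X.shape[1]):
    block = flat[:, j * width:(j + 1) * width]
    assert np.allclose(block, L.legvander(X[:, j], degree))
</code></pre></div></div>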

<p>Finally, note that each block starts with a column of ones: this column is the “bias” term of each polynomial. But we don’t want a bias term for every polynomial; we want <em>one</em> bias term for the entire model. So we will remove these columns of ones, and let the linear regression model in Scikit-Learn have its own bias. Now we’re ready to write our transformer. There is some boilerplate, but I have already explained everything substantial above.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.base</span> <span class="kn">import</span> <span class="n">TransformerMixin</span><span class="p">,</span> <span class="n">BaseEstimator</span>

<span class="k">class</span> <span class="nc">LegendreScalarPolynomialFeatures</span><span class="p">(</span><span class="n">TransformerMixin</span><span class="p">,</span> <span class="n">BaseEstimator</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">degree</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">include_bias</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">degree</span> <span class="o">=</span> <span class="n">degree</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">include_bias</span> <span class="o">=</span> <span class="n">include_bias</span>

    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="c1"># There is nothing to learn
</span>        <span class="c1"># Legendre polynomials do not depend on the training data.
</span>        <span class="k">return</span> <span class="bp">self</span>

    <span class="k">def</span> <span class="nf">__sklearn_is_fitted__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="c1"># See above - it's always "fitted" by definition
</span>        <span class="k">return</span> <span class="bp">True</span>

    <span class="k">def</span> <span class="nf">transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="c1"># Make sure X is of the right type and shape
</span>        <span class="n">X</span> <span class="o">=</span> <span class="n">check_array</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">accept_sparse</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">ensure_all_finite</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

        <span class="c1"># create a Vandermonde matrix for each feature, and create a 3D array
</span>        <span class="c1"># of shape (n_rows, n_features, degree + 1)
</span>        <span class="n">vander</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">degree</span><span class="p">)</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="p">.</span><span class="n">include_bias</span><span class="p">:</span>
            <span class="c1"># discard the column of ones for each feature
</span>            <span class="n">vander</span> <span class="o">=</span> <span class="n">vander</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:]</span> 

        <span class="c1"># reshape to concatenate the Vandermonde matrices horizontally
</span>        <span class="n">n_rows</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">result</span> <span class="o">=</span> <span class="n">vander</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">n_rows</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">result</span>
</code></pre></div></div>

<p>Let’s try it out:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">LegendreScalarPolynomialFeatures</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="mi">4</span><span class="p">).</span><span class="n">transform</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[[ 0.    -0.5   -0.     0.375  0.2   -0.44  -0.28   0.232]
 [ 0.4   -0.26  -0.44  -0.113  0.6    0.04  -0.36  -0.408]
 [ 0.8    0.46   0.08  -0.233  1.     1.     1.     1.   ]]
</code></pre></div></div>

<p>The same matrix we saw before, but <em>without</em> the columns of ones.</p>
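<p>Why drop the per-feature columns of ones? Several identical columns of ones all encode the same constant, so the feature matrix becomes rank-deficient. A minimal sketch on random data (3 features, degree 4, sizes chosen only for illustration) comparing the rank of both layouts:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre as L

rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(100, 3))  # 100 random samples, 3 features
degree = 4
n_rows = X.shape[0]

# Keeping the bias column of every feature: 3 identical columns of ones
with_bias = L.legvander(X, degree).reshape(n_rows, -1)

# Dropping them, and adding a single shared intercept column instead
without_bias = L.legvander(X, degree)[..., 1:].reshape(n_rows, -1)
single_intercept = np.hstack([np.ones((n_rows, 1)), without_bias])

print(with_bias.shape[1], np.linalg.matrix_rank(with_bias))                # 15 13
print(single_intercept.shape[1], np.linalg.matrix_rank(single_intercept))  # 13 13
</code></pre></div></div>

<p>Both layouts span the same space, but only the second one is free of redundant columns, which is what we get by letting <code class="language-plaintext highlighter-rouge">LinearRegression</code> fit its own intercept.</p>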

<p>Let’s see how we can use our shiny new component inside a Scikit-Learn pipeline. Our pipeline will first scale the features to the range \([-1, 1]\), so that we are in the “operating region” of the Legendre basis, then convert the scaled features to Legendre polynomial features using our component, and finally fit a simple linear regression model. Here is how we can build such a pipeline for polynomials of degree 8:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">MinMaxScaler</span>
<span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">LinearRegression</span>
<span class="kn">from</span> <span class="nn">sklearn.pipeline</span> <span class="kn">import</span> <span class="n">Pipeline</span>

<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'minmaxscaler'</span><span class="p">,</span> <span class="n">MinMaxScaler</span><span class="p">(</span><span class="n">feature_range</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">clip</span><span class="o">=</span><span class="bp">True</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'polyfeats'</span><span class="p">,</span> <span class="n">LegendreScalarPolynomialFeatures</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="mi">8</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'model'</span><span class="p">,</span> <span class="n">LinearRegression</span><span class="p">()),</span>
    <span class="p">])</span>
</code></pre></div></div>
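<p>The reason for scaling to \([-1, 1]\) is that every Legendre polynomial satisfies \(|P_n(x)| \le 1\) on that interval, while it grows rapidly outside of it. The <code class="language-plaintext highlighter-rouge">clip=True</code> argument additionally keeps test points outside the training range from leaving the interval. A minimal sketch checking this numerically for \(P_{20}\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre as L

degree = 20
coefs = np.zeros(degree + 1)
coefs[degree] = 1.0  # a coefficient vector selecting P_20 alone

inside = np.linspace(-1, 1, 2001)
max_inside = np.abs(L.legval(inside, coefs)).max()  # about 1, attained at x = -1 and x = 1
outside = L.legval(1.5, coefs)                      # P_20(1.5), far larger than 1
print(max_inside, outside)
</code></pre></div></div>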

<p>The pipeline is simply a sequence of steps, each associated with a name of our choice. Here is how we fit it on the training data and then compute the test error:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">test_error</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Test error = </span><span class="si">{</span><span class="n">test_error</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Test error = 63475.8142
</code></pre></div></div>

<p>An RMSE of about 63,000 may look large, but we are dealing with housing prices, which are typically hundreds of thousands of dollars. So it’s not <em>that</em> large.</p>
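<p>For reference, the RMSE is measured in the same units as the target, which is why it is a dollar amount here. A small sketch computing it by hand and with <code class="language-plaintext highlighter-rouge">root_mean_squared_error</code> (available in Scikit-Learn 1.4 and later); the true values are the first four prices in the table above, while the predictions are made up for illustration:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.metrics import root_mean_squared_error

# True prices: first four rows of the table above; predictions: made up
y_true = np.array([398800.0, 353600.0, 138900.0, 176300.0])
y_pred = np.array([350000.0, 380000.0, 150000.0, 200000.0])

# RMSE = sqrt(mean((y_true - y_pred) ** 2)), in the same units as y (dollars)
manual = np.sqrt(np.mean((y_true - y_pred) ** 2))
print(f"manual: {manual:.2f}, sklearn: {root_mean_squared_error(y_true, y_pred):.2f}")
</code></pre></div></div>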

<p>Now let’s try to reproduce our double descent. We will iterate over several degrees, fit a pipeline to the training data for each, and compute the test errors. Note that we compute <em>test errors</em> and not <em>validation errors</em>, because we want to observe generalization, <em>not</em> tune some parameter. Moreover, to make it run in a reasonable time, we use only a subset of the training set: 5000 out of the roughly 13,600 rows. Here is the code:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">degrees</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">geomspace</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">40000</span><span class="p">,</span> <span class="mi">12</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">int32</span><span class="p">).</span><span class="n">tolist</span><span class="p">()</span>
<span class="n">train_rmses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">test_rmses</span> <span class="o">=</span> <span class="p">[]</span>

<span class="c1"># sample training set - note that it's already randomly permuted
# by the train-test split.
</span><span class="n">n_samples</span> <span class="o">=</span> <span class="mi">5000</span>
<span class="n">X_train_samples</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">.</span><span class="n">iloc</span><span class="p">[:</span><span class="n">n_samples</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">y_train_samples</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">.</span><span class="n">iloc</span><span class="p">[:</span><span class="n">n_samples</span><span class="p">]</span>

<span class="c1"># fit various degrees
</span><span class="k">for</span> <span class="n">degree</span> <span class="ow">in</span> <span class="n">degrees</span><span class="p">:</span>
    <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'minmaxscaler'</span><span class="p">,</span> <span class="n">MinMaxScaler</span><span class="p">(</span><span class="n">feature_range</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">clip</span><span class="o">=</span><span class="bp">True</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'polyfeats'</span><span class="p">,</span> <span class="n">LegendreScalarPolynomialFeatures</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="n">degree</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'model'</span><span class="p">,</span> <span class="n">LinearRegression</span><span class="p">())</span>
    <span class="p">])</span>
    <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train_samples</span><span class="p">,</span> <span class="n">y_train_samples</span><span class="p">)</span>
    <span class="n">y_train_pred</span>  <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_train_samples</span><span class="p">)</span>
    <span class="n">y_test_pred</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>

    <span class="n">train_rmses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_train_samples</span><span class="p">,</span> <span class="n">y_train_pred</span><span class="p">))</span>
    <span class="n">test_rmses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_test_pred</span><span class="p">))</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Degree: </span><span class="si">{</span><span class="n">degree</span><span class="si">}</span><span class="s">, "</span>
          <span class="sa">f</span><span class="s">"Test RMSE: </span><span class="si">{</span><span class="n">test_rmses</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">, "</span>
          <span class="sa">f</span><span class="s">"Train RMSE </span><span class="si">{</span><span class="n">train_rmses</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Degree: 1, Test RMSE: 68326.20, Train RMSE 66776.99
Degree: 2, Test RMSE: 67522.42, Train RMSE 65720.98
Degree: 3, Test RMSE: 65784.84, Train RMSE 64030.37
Degree: 4, Test RMSE: 65227.71, Train RMSE 63353.57
Degree: 5, Test RMSE: 64907.57, Train RMSE 62297.19
Degree: 6, Test RMSE: 65134.09, Train RMSE 62006.36
Degree: 7, Test RMSE: 64675.77, Train RMSE 61721.51
Degree: 8, Test RMSE: 64534.99, Train RMSE 61459.37
Degree: 9, Test RMSE: 64962.61, Train RMSE 61285.31
Degree: 10, Test RMSE: 64433.99, Train RMSE 60810.69
Degree: 21, Test RMSE: 1701699.31, Train RMSE 57939.34
Degree: 45, Test RMSE: 2786592773542.82, Train RMSE 55145.90
Degree: 96, Test RMSE: 24926913710269.14, Train RMSE 50786.59
Degree: 204, Test RMSE: 6879185912413.60, Train RMSE 47511.24
Degree: 433, Test RMSE: 5893934722602.69, Train RMSE 41896.23
Degree: 922, Test RMSE: 1642004977035.12, Train RMSE 32491.25
Degree: 1959, Test RMSE: 295197737240.98, Train RMSE 14681.05
Degree: 4165, Test RMSE: 116144.54, Train RMSE 0.00
Degree: 8854, Test RMSE: 85373.33, Train RMSE 0.00
Degree: 18819, Test RMSE: 78639.33, Train RMSE 0.00
Degree: 40000, Test RMSE: 75965.36, Train RMSE 0.00
</code></pre></div></div>

<p>We see a nice double descent! The train error steadily goes down towards zero. The test error first decreases slightly, then explodes somewhere between degrees 10 and 21, and then decreases again as the degree keeps growing. It is not obvious where the “memorization threshold” is this time, since the features are correlated. For example, total rooms and total bedrooms are strongly correlated:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">corr</span><span class="p">(</span><span class="n">X_train</span><span class="p">[</span><span class="s">"total_rooms"</span><span class="p">],</span> <span class="n">X_train</span><span class="p">[</span><span class="s">"total_bedrooms"</span><span class="p">])</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[[1.         0.94465316]
 [0.94465316 1.        ]]
</code></pre></div></div>

<p>This means, for instance, that the block of Legendre features for total bedrooms does not necessarily add much new information. Thus, it is not trivial to pinpoint the “memorization threshold” in terms of the polynomial degree. But it lies somewhere between degrees 1959 and 4165, since that is where the train error drops to zero.</p>
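<p>One way to think about the threshold: the train error can drop to zero only once the feature matrix, together with the intercept, reaches rank equal to the number of training samples, and dependencies between columns keep the rank below the column count. A minimal sketch of an extreme case, assuming a hypothetical integer-valued column with only 52 distinct values: no matter how high the degree, its Legendre features never reach a rank above 52.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre as L

rng = np.random.default_rng(0)
n_rows = 500

# Hypothetical integer-valued column (e.g. an age between 1 and 52),
# min-max scaled to [-1, 1]; it has at most 52 distinct values
age = rng.integers(1, 53, size=n_rows)
age_scaled = 2 * (age - 1) / 51 - 1

ranks = {}
for degree in (10, 100, 1000):
    feats = L.legvander(age_scaled, degree)[:, 1:]     # drop the bias column
    design = np.hstack([np.ones((n_rows, 1)), feats])  # one shared intercept
    ranks[degree] = np.linalg.matrix_rank(design)
    print(f"degree {degree}: {design.shape[1]} columns, rank {ranks[degree]}")
</code></pre></div></div>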

<p>We can also plot the double descent curve using the train and test errors we just stored:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">degrees</span><span class="p">,</span> <span class="n">train_rmses</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Train"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">degrees</span><span class="p">,</span> <span class="n">test_rmses</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Test"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylim</span><span class="p">([</span><span class="o">-</span><span class="mf">0.1</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">test_rmses</span><span class="p">)])</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"RMSE"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xscale</span><span class="p">(</span><span class="s">'log'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"Polynomial degree"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_yscale</span><span class="p">(</span><span class="s">'asinh'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>

<span class="n">fig</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/california_housing_legendre_double_descent.png" alt="california_housing_legendre_double_descent" /></p>

<p>It’s interesting to see that the test error with polynomial features of degree 40,000 is quite small, but it is still worse than that of the low-degree polynomials. I suspect that cranking up the degree to a few million would improve it further, but that would be overkill. Having demonstrated the double descent, I want to take this post in a different direction.</p>

<h1 id="pruning">Pruning</h1>

<p>First, let’s check whether pruning the “tail” of the polynomial even makes sense, that is, whether higher degrees merely add finer details to an already well-formed polynomial. To that end, let’s plot the polynomial we obtained for each feature using only its first \(k\) coefficients, for various values of \(k\).</p>
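<p>Here is a minimal sketch of what truncation means, on a hypothetical smooth function rather than our trained model: we take the Legendre coefficients \(c_0, \dots, c_{20}\) of \(e^x\), keep \(c_0\) and the first \(k\) non-constant coefficients, and measure how far the truncated polynomial deviates from the full one. In our pipeline, the constant term lives in the model’s intercept, and each feature’s coefficients play the role of \(c_1, c_2, \dots\).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre as L

# Hypothetical smooth function: Legendre coefficients c_0..c_20 of exp on [-1, 1]
x = np.linspace(-1, 1, 501)
coefs = L.legfit(x, np.exp(x), deg=20)
full_values = L.legval(x, coefs)

def truncated_values(coefs, x, k):
    """Evaluate the series keeping c_0 and c_1..c_k, zeroing the tail."""
    kept = np.zeros_like(coefs)
    kept[:k + 1] = coefs[:k + 1]
    return L.legval(x, kept)

errs = []
for k in (1, 2, 4, 8):
    errs.append(np.abs(truncated_values(coefs, x, k) - full_values).max())
    print(f"k = {k}: max deviation from the full polynomial = {errs[-1]:.2e}")
</code></pre></div></div>

<p>For a smooth function like \(e^x\) the coefficients decay quickly, so the deviation shrinks fast with \(k\); whether our trained polynomials behave similarly is exactly what we want to find out.</p>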

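<p>Recall that our pipeline has a step named <code class="language-plaintext highlighter-rouge">model</code>, which is a linear regression model. The <code class="language-plaintext highlighter-rouge">pipeline</code> variable still holds the last pipeline fitted in the loop above, with degree 40,000, so we can access its coefficients:</p>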

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lin_reg</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">lin_reg</span><span class="p">.</span><span class="n">coef_</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(320000,)
</code></pre></div></div>

<p>We see that we have exactly \(8 \times 40{,}000 = 320{,}000\) coefficients. This is because we have 8 columns, each represented by 40,000 coefficients of a polynomial of degree 40,000 without its bias term. We can access the coefficients of each polynomial by reshaping these coefficients into a matrix of 8 rows. This is exactly what the following function does - extracts the coefficient matrix, with a row of coefficients for each feature:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_feature_coefs</span><span class="p">(</span><span class="n">pipeline</span><span class="p">):</span>
    <span class="n">num_features</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'minmaxscaler'</span><span class="p">].</span><span class="n">n_features_in_</span>
    <span class="n">lin_reg</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span>
    <span class="n">feature_coefs</span> <span class="o">=</span> <span class="n">lin_reg</span><span class="p">.</span><span class="n">coef_</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">feature_coefs</span>
</code></pre></div></div>
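<p>To convince ourselves that this reshape assigns each coefficient to the right feature, here is a small self-contained check on random data (the sizes and coefficients are arbitrary assumptions): the linear model’s output, without the intercept, should equal a sum of per-feature Legendre series, where the series of feature \(i\) has a zero constant term followed by row \(i\) of the reshaped coefficients.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre as L

rng = np.random.default_rng(0)
n_rows, n_features, degree = 6, 3, 5
X = rng.uniform(-1, 1, size=(n_rows, n_features))

# Features laid out like LegendreScalarPolynomialFeatures(degree) without bias
feats = L.legvander(X, degree)[..., 1:].reshape(n_rows, -1)

# Arbitrary flat coefficient vector, playing the role of lin_reg.coef_
coef = rng.normal(size=feats.shape[1])
coef_matrix = coef.reshape(n_features, -1)  # same reshape as get_feature_coefs

# Sum of per-feature series: zero constant term, then row i of coef_matrix
per_feature = sum(
    L.legval(X[:, i], np.concatenate(([0.0], coef_matrix[i])))
    for i in range(n_features)
)
print(np.allclose(feats @ coef, per_feature))  # True
</code></pre></div></div>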

<p>Now we can use it to plot the coefficients of each feature. We call the coefficient vector a <em>spectrum</em>, because Legendre polynomials model oscillations, just like sines and cosines.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">plot_spectra</span><span class="p">(</span><span class="n">pipeline</span><span class="p">):</span>
    <span class="n">feature_coefs</span> <span class="o">=</span> <span class="n">get_feature_coefs</span><span class="p">(</span><span class="n">pipeline</span><span class="p">)</span>

    <span class="c1"># define subplots for each feature
</span>    <span class="n">n_cols</span> <span class="o">=</span> <span class="mi">3</span>
    <span class="n">n_rows</span> <span class="o">=</span> <span class="n">math</span><span class="p">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">feature_coefs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">n_cols</span><span class="p">)</span>
    <span class="n">width</span><span class="p">,</span> <span class="n">height</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">rcParams</span><span class="p">[</span><span class="s">'figure.figsize'</span><span class="p">]</span>
    <span class="n">fig</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span>
        <span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">[</span><span class="n">n_cols</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">n_rows</span> <span class="o">*</span> <span class="n">height</span><span class="p">],</span>
        <span class="n">sharex</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
        <span class="n">layout</span><span class="o">=</span><span class="s">'constrained'</span><span class="p">)</span>

    <span class="c1"># plot coefficients
</span>    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">coef_vec</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">feature_coefs</span><span class="p">,</span> <span class="n">axs</span><span class="p">.</span><span class="n">ravel</span><span class="p">())):</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">coef_vec</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">'</span><span class="si">{</span><span class="n">X_train</span><span class="p">.</span><span class="n">columns</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>

    <span class="n">fig</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>

<span class="n">plot_spectra</span><span class="p">(</span><span class="n">pipeline</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/california_housing_spectra.png" alt="california_housing_spectra" /></p>

<p>Indeed, for every feature the coefficients “decay” towards zero. For some features the decay is fast, for others it is slow, but it happens for all of them. So “pruning” makes sense: we remove the fine details modeled by the rapid oscillations of the high-degree polynomials, and keep the overall shape.</p>
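<p>The “rapid oscillations” of high degrees can be made concrete. The Legendre polynomial \(P_n\) has exactly \(n\) simple roots in \((-1, 1)\), so it changes sign \(n\) times. A quick numerical check, counting sign changes on a fine grid (a sketch; the grid size is an arbitrary choice):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre

# an even number of points, so the grid avoids x = 0, a root of every odd-degree P_n
x = np.linspace(-1, 1, 10_000)

for n in [1, 2, 5, 10, 20, 50]:
    # the coefficient vector np.eye(n + 1)[n] selects P_n alone
    p_n = legendre.legval(x, np.eye(n + 1)[n])
    sign_changes = np.count_nonzero(np.diff(np.sign(p_n)))
    print(f"degree {n}: {sign_changes} sign changes")
</code></pre></div></div>

<p>Higher degrees therefore contribute faster wiggles, which is why cutting the tail of the spectrum smooths each curve.</p>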

<p>So let’s define a function that prunes each polynomial by preserving only the initial \(k\) coefficients for several values of \(k\), and see the results. The function is a bit lengthy due to boilerplate, but pretty straightforward:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">plot_feature_curves</span><span class="p">(</span>
        <span class="n">pipeline</span><span class="p">,</span> <span class="n">pruned_degrees</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">2000</span><span class="p">],</span> <span class="n">plot_resolution</span><span class="o">=</span><span class="mi">5000</span>
<span class="p">):</span>
    <span class="c1"># extract coefficients of each feature from the model
</span>    <span class="n">feature_coefs</span> <span class="o">=</span> <span class="n">get_feature_coefs</span><span class="p">(</span><span class="n">pipeline</span><span class="p">)</span>
    <span class="n">num_features</span> <span class="o">=</span> <span class="n">feature_coefs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

    <span class="c1"># define grid of plots for features x degrees
</span>    <span class="n">width</span><span class="p">,</span> <span class="n">height</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">rcParams</span><span class="p">[</span><span class="s">'figure.figsize'</span><span class="p">]</span>
    <span class="n">fig</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span>
        <span class="n">feature_coefs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">len</span><span class="p">(</span><span class="n">pruned_degrees</span><span class="p">),</span>
        <span class="n">layout</span><span class="o">=</span><span class="s">'constrained'</span><span class="p">,</span> <span class="n">sharex</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">sharey</span><span class="o">=</span><span class="s">'row'</span><span class="p">,</span>
        <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">pruned_degrees</span><span class="p">)</span> <span class="o">/</span> <span class="mi">3</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">num_features</span> <span class="o">/</span> <span class="mi">3</span><span class="p">))</span>

    <span class="c1"># compute Legendre Vandermonde matrix of maximum degree, to be pruned
</span>    <span class="n">x_plot</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">plot_resolution</span><span class="p">)</span>
    <span class="n">full_vander</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">(</span><span class="n">x_plot</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="n">pruned_degrees</span><span class="p">))</span>

    <span class="c1"># do the plotting for each feature and degree
</span>    <span class="k">for</span> <span class="n">feat</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_features</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">degree</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">pruned_degrees</span><span class="p">):</span>
            <span class="c1"># prune coefficients of current feature at current degree; column j
</span>            <span class="c1"># multiplies the Legendre polynomial of degree j + 1 (there is no bias term)
</span>            <span class="n">pruned_coefs</span> <span class="o">=</span> <span class="n">feature_coefs</span><span class="p">[</span><span class="n">feat</span><span class="p">,</span> <span class="p">:</span><span class="n">degree</span><span class="p">]</span>

            <span class="c1"># matching Vandermonde columns: degrees 1 to the current degree
</span>            <span class="n">pruned_vander</span> <span class="o">=</span> <span class="n">full_vander</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">1</span> <span class="o">+</span> <span class="n">degree</span><span class="p">]</span>

            <span class="c1"># plot the current degree polynomial
</span>            <span class="n">y_plot</span> <span class="o">=</span> <span class="n">pruned_vander</span> <span class="o">@</span> <span class="n">pruned_coefs</span>
            <span class="n">axs</span><span class="p">[</span><span class="n">feat</span><span class="p">,</span> <span class="n">i</span><span class="p">].</span><span class="n">plot</span><span class="p">(</span><span class="n">x_plot</span><span class="p">,</span> <span class="n">y_plot</span><span class="p">)</span>

            <span class="c1"># put axis titles
</span>            <span class="k">if</span> <span class="n">feat</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">axs</span><span class="p">[</span><span class="n">feat</span><span class="p">,</span> <span class="n">i</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">"deg=</span><span class="si">{</span><span class="n">degree</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">feature_name</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">.</span><span class="n">columns</span><span class="p">[</span><span class="n">feat</span><span class="p">].</span><span class="n">replace</span><span class="p">(</span><span class="s">"_"</span><span class="p">,</span> <span class="s">"</span><span class="se">\n</span><span class="s">"</span><span class="p">)</span>
                <span class="n">axs</span><span class="p">[</span><span class="n">feat</span><span class="p">,</span> <span class="n">i</span><span class="p">].</span><span class="n">set_ylabel</span><span class="p">(</span><span class="n">feature_name</span><span class="p">)</span>

    <span class="n">fig</span><span class="p">.</span><span class="n">align_ylabels</span><span class="p">(</span><span class="n">axs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
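    <span class="n">fig</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>

<span class="n">plot_feature_curves</span><span class="p">(</span><span class="n">pipeline</span><span class="p">)</span>
</code></pre></div></div>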

<p><img src="https://alexshtf.github.io/assets/california_housing_pruned_polys.png" alt="california_housing_pruned_polys" /></p>

<p>We can see some interesting things. For example, for the <em>population</em> feature, the polynomial of degree 3 already captures the overall shape quite well, and for median income the lower degrees also do a good job. Moreover, some features, notably total rooms, total bedrooms, population, and households, show oscillations that look like noise added on top of the overall shape. These are exactly the features whose spectrum decays slowly: the high-degree terms, which model rapid oscillations, have a larger effect.</p>
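<p>One way to put a number on “decays slowly” is the share of a polynomial's squared \(L^2\) norm on \([-1, 1]\) carried by its high-degree terms. Since \(\int_{-1}^{1} P_m(x) P_n(x)\,dx\) equals \(\frac{2}{2n+1}\) when \(m = n\) and zero otherwise, the polynomial \(\sum_n c_n P_n\) has squared norm \(\sum_n \frac{2 c_n^2}{2n+1}\). The sketch below computes this share with a hypothetical <code class="language-plaintext highlighter-rouge">tail_energy_share</code> helper, assuming the polynomials live on \([-1, 1]\) and that coefficient column \(j\) corresponds to degree \(j + 1\) (no bias term):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def tail_energy_share(coefs, keep_degree):
    # hypothetical helper; assumes coefs[j] multiplies the Legendre polynomial
    # of degree j + 1 on [-1, 1] (bias term excluded)
    degrees = np.arange(1, len(coefs) + 1)
    # squared L2 norm of each term: c_n^2 * 2 / (2n + 1), by orthogonality
    energy = 2 * coefs ** 2 / (2 * degrees + 1)
    # share of the total carried by degrees above keep_degree
    return energy[keep_degree:].sum() / energy.sum()

# toy spectra: one decaying quickly, one decaying slowly
n = np.arange(1, 1001)
fast_decay = 1.0 / n ** 2
slow_decay = 1.0 / np.sqrt(n)
print(tail_energy_share(fast_decay, 10), tail_energy_share(slow_decay, 10))
</code></pre></div></div>

<p>Applied to each row of <code class="language-plaintext highlighter-rouge">get_feature_coefs(pipeline)</code>, such a score gives one number per feature for how much of its polynomial lives in the high degrees.</p>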

<p>So here comes an interesting conjecture: maybe by keeping just the “overall shape”, i.e., pruning the higher-order coefficients, we can achieve an even better generalization error? Well, let’s put it to the test! Here is a small function that takes a pipeline and creates a new one of a lower degree with pruned coefficients. The code is quite straightforward: remove the “tail” of coefficients, and make the polynomial features of the corresponding degree:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>

<span class="k">def</span> <span class="nf">prune_pipeline</span><span class="p">(</span><span class="n">pipeline</span><span class="p">,</span> <span class="n">pruned_deg</span><span class="p">):</span>
    <span class="c1"># work on a copy, so that the original pipeline stays intact
</span>    <span class="n">pruned</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">pipeline</span><span class="p">)</span>
    <span class="n">num_features</span> <span class="o">=</span> <span class="n">pruned</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'minmaxscaler'</span><span class="p">].</span><span class="n">n_features_in_</span>
    <span class="n">orig_degree</span> <span class="o">=</span> <span class="n">pruned</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'polyfeats'</span><span class="p">].</span><span class="n">degree</span>
    <span class="n">pruned</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'polyfeats'</span><span class="p">].</span><span class="n">degree</span> <span class="o">=</span> <span class="n">pruned_deg</span>

    <span class="c1"># keep the first pruned_deg coefficients (degrees 1 to pruned_deg) of each feature
</span>    <span class="n">lin_reg</span> <span class="o">=</span> <span class="n">pruned</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span>
    <span class="n">orig_coef</span> <span class="o">=</span> <span class="n">lin_reg</span><span class="p">.</span><span class="n">coef_</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="n">orig_degree</span><span class="p">)</span>
    <span class="n">pruned_coef</span> <span class="o">=</span> <span class="n">orig_coef</span><span class="p">[:,</span> <span class="p">:</span><span class="n">pruned_deg</span><span class="p">].</span><span class="n">ravel</span><span class="p">()</span>
    <span class="n">lin_reg</span><span class="p">.</span><span class="n">coef_</span> <span class="o">=</span> <span class="n">pruned_coef</span>
    <span class="n">lin_reg</span><span class="p">.</span><span class="n">n_features_in_</span> <span class="o">=</span> <span class="n">pruned_coef</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

    <span class="k">return</span> <span class="n">pruned</span>
</code></pre></div></div>
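<p>Truncating each feature's coefficients together with the feature degree, as <code class="language-plaintext highlighter-rouge">prune_pipeline</code> does, gives the same predictions as keeping the full degree and zeroing the tail of the coefficients, because the dropped terms would be multiplied by zero anyway. A small self-contained check, with a hypothetical <code class="language-plaintext highlighter-rouge">legendre_features</code> helper that mimics the assumed layout (degrees 1 to \(d\) per feature, no bias, concatenated feature by feature):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre

# hypothetical stand-in for the pipeline's feature map (an assumption):
# Legendre degrees 1..degree of each column, concatenated feature by feature
def legendre_features(X, degree):
    return np.hstack([legendre.legvander(X[:, j], degree)[:, 1:] for j in range(X.shape[1])])

rng = np.random.default_rng(1)
num_features, full_deg, pruned_deg = 4, 30, 7
X = rng.uniform(-1, 1, size=(100, num_features))
coef = rng.normal(size=(num_features, full_deg))  # one row per feature

# option 1: lower the feature degree and truncate each row of coefficients
pred_truncated = legendre_features(X, pruned_deg) @ coef[:, :pruned_deg].ravel()

# option 2: keep the full degree, but zero the tail of each row
zeroed = coef.copy()
zeroed[:, pruned_deg:] = 0
pred_zeroed = legendre_features(X, full_deg) @ zeroed.ravel()

print(np.allclose(pred_truncated, pred_zeroed))
</code></pre></div></div>

<p>This equivalence is what makes zeroing a convenient substitute for truncation when different features need different degrees.</p>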

<p>Now we can create degree-pruned pipelines, select the degree giving the smallest <em>validation error</em>, and compute the test error. Note that this is the first time we use our validation set, since the degree of the pruned pipeline is the first parameter we actually tune:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># get a set of degrees to try pruning at.
</span><span class="n">prune_degrees</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span>
    <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">),</span>
    <span class="n">np</span><span class="p">.</span><span class="n">geomspace</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">5000</span><span class="p">,</span> <span class="mi">50</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
<span class="p">])</span>

<span class="c1"># compute validation error for each pruned degree
</span><span class="n">pruned_errors</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">prune_degrees</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">degree</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">prune_degrees</span><span class="p">):</span>
    <span class="n">pruned</span> <span class="o">=</span> <span class="n">prune_pipeline</span><span class="p">(</span><span class="n">pipeline</span><span class="p">,</span> <span class="n">degree</span><span class="p">)</span>
    <span class="n">y_valid_pred</span> <span class="o">=</span> <span class="n">pruned</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_valid</span><span class="p">)</span>
    <span class="n">pruned_errors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_valid</span><span class="p">,</span> <span class="n">y_valid_pred</span><span class="p">)</span>
    
<span class="c1"># compute the test error for the optimal degree
</span><span class="n">best_degree</span> <span class="o">=</span> <span class="n">prune_degrees</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">argmin</span><span class="p">(</span><span class="n">pruned_errors</span><span class="p">)]</span>
<span class="n">pruned</span> <span class="o">=</span> <span class="n">prune_pipeline</span><span class="p">(</span><span class="n">pipeline</span><span class="p">,</span> <span class="n">best_degree</span><span class="p">)</span>
<span class="n">pruned_test_error</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pruned</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Best degree = </span><span class="si">{</span><span class="n">best_degree</span><span class="si">}</span><span class="s">, test error = </span><span class="si">{</span><span class="n">pruned_test_error</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Best degree = 100, test error = 60014.842969972255
</code></pre></div></div>

<p>This is much better than the test error we obtained by plain fitting! When we plotted the double descent curve, the best test error, 64534.99, was obtained by a polynomial of degree 8. That’s a very nice improvement of approximately 7% in RMSE!</p>

<p>Obviously, we could also try a <em>regularized</em> fit, and tune the regularization coefficient. It may even yield a better generalization error, but that misses the point. Pruning is a <em>post-fitting</em> procedure: we first fit a model without tuning anything, and then tune it by pruning coefficients. A lot of coefficients! We reduce each feature's polynomial from degree 40,000 to degree 100!</p>
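<p>The figures in this comparison can be recomputed directly from the reported values: the relative RMSE improvement over the best degree-8 fit, and the number of polynomial coefficients when eight features go from degree 40,000 to degree 100 (the intercept, which pruning leaves alone, is not counted):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>best_fit_rmse = 64534.99            # best degree-8 fit, from the double descent experiment
pruned_rmse = 60014.842969972255    # all features pruned to degree 100

improvement = (best_fit_rmse - pruned_rmse) / best_fit_rmse
print(f"relative RMSE improvement: {improvement:.1%}")

# coefficient counts: 8 features, bias terms excluded
num_features = 8
full_coefs = num_features * 40_000
pruned_coefs = num_features * 100
print(f"coefficients: {full_coefs:,} to {pruned_coefs:,} ({full_coefs // pruned_coefs}x fewer)")
</code></pre></div></div>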

<p>Let’s plot the validation error as a function of the pruned degree, and also add the original best test error, the best pruned validation error, and the best pruned test error to the plot:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">prune_degrees</span><span class="p">,</span> <span class="n">pruned_errors</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'Pruned valid'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Pruned degree"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"RMSE"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xscale</span><span class="p">(</span><span class="s">'asinh'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axhline</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">pruned_errors</span><span class="p">),</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'--'</span><span class="p">,</span>
            <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">'Best pruned: </span><span class="si">{</span><span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">pruned_errors</span><span class="p">)</span><span class="si">:</span><span class="p">.</span><span class="mi">3</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'orange'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axhline</span><span class="p">(</span><span class="n">pruned_test_error</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'--'</span><span class="p">,</span>
            <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">'Best pruned test: </span><span class="si">{</span><span class="n">pruned_test_error</span><span class="si">:</span><span class="p">.</span><span class="mi">3</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'red'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axhline</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">test_rmses</span><span class="p">),</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'--'</span><span class="p">,</span>
            <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">'Best fit: </span><span class="si">{</span><span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">test_rmses</span><span class="p">)</span><span class="si">:</span><span class="p">.</span><span class="mi">3</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'navy'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/california_housing_pruned_errors.png" alt="california_housing_pruned_errors" /></p>

<p>The blue curve shows that pruning the polynomials to a degree of approximately 100 is a good choice in terms of the validation error, and the dashed orange line marks that best validation error. The corresponding test error is shown in red. For comparison, the best test error obtained by fitting a polynomial of degree 8 is shown in navy.</p>

<p>Intuitively, it appears that fitting polynomials of higher degrees lets the model separate “signal” from “noise”, and by pruning we remove the noise and keep the signal. It’s not a rigorous analysis, but the intuition is consistent with the decaying spectrum and with our visual observation that lower-degree polynomials indeed capture the high-level shape. The fact that Legendre polynomials act like a frequency spectrum lets us explicitly “distil” the simple model hiding inside the highly overparameterized one. It’s not even hiding: it’s in plain sight, in the lower-degree coefficients.</p>
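<p>To illustrate the signal-versus-noise intuition in a setting where we control both parts, here is a constructed toy example, not the housing data. The “signal” is \(\sin(3x)\), the “noise” is a single high-degree term \(0.1\, P_{150}(x)\), and the Legendre coefficients up to degree 200 are computed by Gauss-Legendre quadrature from \(c_n = \frac{2n+1}{2}\int_{-1}^{1} f(x) P_n(x)\,dx\). The sketch compares the full and the pruned expansions against the signal alone:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre

def legendre_coefs(f, degree, num_nodes=400):
    # project f onto P_0..P_degree: c_n = (2n + 1) / 2 * integral of f * P_n over [-1, 1],
    # with the integral computed by Gauss-Legendre quadrature
    nodes, weights = legendre.leggauss(num_nodes)
    vander = legendre.legvander(nodes, degree)  # shape (num_nodes, degree + 1)
    norms = 2.0 / (2 * np.arange(degree + 1) + 1)
    return vander.T @ (weights * f(nodes)) / norms

def signal(x):
    return np.sin(3 * x)

def noisy(x):
    # the signal plus a rapidly oscillating degree-150 term
    return signal(x) + 0.1 * legendre.legval(x, np.eye(151)[150])

coefs = legendre_coefs(noisy, degree=200)
x = np.linspace(-1, 1, 2001)
full_error = np.max(np.abs(legendre.legval(x, coefs) - signal(x)))
pruned_error = np.max(np.abs(legendre.legval(x, coefs[:21]) - signal(x)))
print(f"max error, degree 200: {full_error:.3g}; pruned to degree 20: {pruned_error:.3g}")
</code></pre></div></div>

<p>In this construction, the degree-200 expansion reproduces the high-degree term along with the signal, so its maximum error stays around 0.1, while pruning to degree 20 drops that term and leaves the smooth sine essentially intact.</p>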

<p>So let’s take it one step further: using our validation set, we can greedily prune the polynomial of each feature separately. Since it is convenient to keep all polynomials at the same degree, so that their coefficients fit in one matrix, we prune by zeroing out the tail of each feature’s coefficients instead of removing it. Here is a function that prunes the degree of <em>one</em> given feature this way:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">prune_feature</span><span class="p">(</span><span class="n">pipeline</span><span class="p">,</span> <span class="n">feature</span><span class="p">,</span> <span class="n">pruned_deg</span><span class="p">):</span>
    <span class="n">pruned</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">pipeline</span><span class="p">)</span>
    <span class="n">num_features</span> <span class="o">=</span> <span class="n">pruned</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'minmaxscaler'</span><span class="p">].</span><span class="n">n_features_in_</span>
    <span class="n">full_deg</span> <span class="o">=</span> <span class="n">pruned</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'polyfeats'</span><span class="p">].</span><span class="n">degree</span>

    <span class="n">regressor</span> <span class="o">=</span> <span class="n">pruned</span><span class="p">.</span><span class="n">named_steps</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span>
    <span class="n">coef</span> <span class="o">=</span> <span class="n">regressor</span><span class="p">.</span><span class="n">coef_</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="n">full_deg</span><span class="p">)</span>
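    <span class="c1"># zero this feature's coefficients from column 1 + pruned_deg onward
</span>    <span class="c1"># (column j multiplies the Legendre polynomial of degree j + 1)
</span>    <span class="n">coef</span><span class="p">[</span><span class="n">feature</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">pruned_deg</span><span class="p">):]</span> <span class="o">=</span> <span class="mi">0</span>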
    <span class="n">regressor</span><span class="p">.</span><span class="n">coef_</span> <span class="o">=</span> <span class="n">coef</span><span class="p">.</span><span class="n">ravel</span><span class="p">()</span>

    <span class="k">return</span> <span class="n">pruned</span>
</code></pre></div></div>
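<p>The same zeroing can be applied to all features at once, given a hypothetical vector <code class="language-plaintext highlighter-rouge">num_kept</code> holding how many leading coefficients to keep for each feature (for <code class="language-plaintext highlighter-rouge">prune_feature</code> above, which zeroes from column <code class="language-plaintext highlighter-rouge">1 + pruned_deg</code> onward, that count is <code class="language-plaintext highlighter-rouge">1 + pruned_deg</code>). A minimal NumPy sketch on a coefficient matrix with one row per feature:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def zero_tails(coef_matrix, num_kept):
    # hypothetical helper: coef_matrix has one row of polynomial coefficients
    # per feature, and num_kept[f] leading coefficients of feature f survive
    num_kept = np.asarray(num_kept)
    columns = np.arange(coef_matrix.shape[1])
    # True where the column index is at or beyond the feature's cutoff
    tail_mask = np.greater_equal(columns[None, :], num_kept[:, None])
    return np.where(tail_mask, 0.0, coef_matrix)

coefs = np.arange(1.0, 13.0).reshape(3, 4)
print(zero_tails(coefs, [1, 4, 2]))
</code></pre></div></div>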

<p>Now we can loop over a set of polynomial degrees for each feature, and select the degree that gives the best validation error:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># start from the pipeline pruned to degree 200, then prune each feature greedily
</span><span class="n">pruned_pipeline</span> <span class="o">=</span> <span class="n">prune_pipeline</span><span class="p">(</span><span class="n">pipeline</span><span class="p">,</span> <span class="mi">200</span><span class="p">)</span>
<span class="n">degrees_to_try</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">200</span><span class="p">)</span>

<span class="k">for</span> <span class="n">feature</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
    <span class="n">best_deg</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">best_error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">inf</span>
    <span class="k">for</span> <span class="n">deg</span> <span class="ow">in</span> <span class="n">degrees_to_try</span><span class="p">:</span>
        <span class="n">candidate</span> <span class="o">=</span> <span class="n">prune_feature</span><span class="p">(</span><span class="n">pruned_pipeline</span><span class="p">,</span> <span class="n">feature</span><span class="p">,</span> <span class="n">deg</span><span class="p">)</span>
        <span class="n">pred</span> <span class="o">=</span> <span class="n">candidate</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_valid</span><span class="p">)</span>
        <span class="n">error</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_valid</span><span class="p">,</span> <span class="n">pred</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">error</span> <span class="o">&lt;=</span> <span class="n">best_error</span><span class="p">:</span>
            <span class="n">best_error</span> <span class="o">=</span> <span class="n">error</span>
            <span class="n">best_deg</span> <span class="o">=</span> <span class="n">deg</span>

    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Best degree for feature </span><span class="si">{</span><span class="n">X_train</span><span class="p">.</span><span class="n">columns</span><span class="p">[</span><span class="n">feature</span><span class="p">]</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">best_deg</span><span class="si">}</span><span class="s">, validation error: </span><span class="si">{</span><span class="n">best_error</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="n">pruned_pipeline</span> <span class="o">=</span> <span class="n">prune_feature</span><span class="p">(</span><span class="n">pruned_pipeline</span><span class="p">,</span> <span class="n">feature</span><span class="p">,</span> <span class="n">best_deg</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Best degree for feature longitude: 188, validation error: 59310.36603502216
Best degree for feature latitude: 198, validation error: 59302.68099683055
Best degree for feature housing_median_age: 16, validation error: 59072.43414192942
Best degree for feature total_rooms: 23, validation error: 58793.32412447594
Best degree for feature total_bedrooms: 34, validation error: 58246.908681064
Best degree for feature population: 6, validation error: 57845.32710297523
Best degree for feature households: 14, validation error: 57614.71621402189
Best degree for feature median_income: 8, validation error: 56791.08354124084
</code></pre></div></div>

<p>We can see that some features need higher degrees to represent the right overall shape, whereas others work well with lower degrees. This simple post-training procedure lets us actually customize the polynomial degree of each feature separately! Now let’s see the test error:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test_error</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pruned_pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Best pipeline test error: "</span><span class="p">,</span> <span class="n">test_error</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Best pipeline test error:  59497.31961895119
</code></pre></div></div>

<p>It appears we squeezed out almost another percent: by choosing the right degree for each feature, the test error went down from 60014.84 to 59497.32, a reduction of about 0.86%.</p>

<h1 id="comparing-to-regularized-regression">Comparing to regularized regression</h1>

<p>Pruning is nice, but aren’t we all taught to use regularization and tune the regularization coefficient instead? Well, let’s try this as well. It’s quite simple: construct a similar pipeline with a <code class="language-plaintext highlighter-rouge">Ridge</code> regression object, which adds L2 regularization to least-squares regression. Then, tune both the regularization coefficient and the polynomial degree using <a href="https://hyperopt.github.io/hyperopt/">HyperOpt</a>, a pretty good hyperparameter tuner that comes preinstalled with Colab. First, we define a function that creates a Ridge regression pipeline given a polynomial degree and a regularization coefficient:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">Ridge</span>

<span class="k">def</span> <span class="nf">make_ridge_pipeline</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">alpha</span><span class="p">):</span>
    <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">([</span>
        <span class="p">(</span><span class="s">'minmaxscaler'</span><span class="p">,</span> <span class="n">MinMaxScaler</span><span class="p">(</span><span class="n">feature_range</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">clip</span><span class="o">=</span><span class="bp">True</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'polyfeats'</span><span class="p">,</span> <span class="n">LegendreScalarPolynomialFeatures</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="n">degree</span><span class="p">)),</span>
        <span class="p">(</span><span class="s">'model'</span><span class="p">,</span> <span class="n">Ridge</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">)),</span>
    <span class="p">])</span>
    <span class="k">return</span> <span class="n">pipeline</span>
</code></pre></div></div>
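<p>To make the role of <code class="language-plaintext highlighter-rouge">alpha</code> concrete, here is a small NumPy sketch of the objective that the scikit-learn documentation states for <code class="language-plaintext highlighter-rouge">Ridge</code>, \(\|\boldsymbol y - \boldsymbol X \boldsymbol w\|_2^2 + \alpha \|\boldsymbol w\|_2^2\). It is an illustration under simplifying assumptions, not scikit-learn’s implementation: there is no intercept (scikit-learn fits one but leaves it unpenalized), the data is random, and <code class="language-plaintext highlighter-rouge">ridge_coefficients</code> is a hypothetical helper of our own. Setting \(\alpha = 0\) recovers plain least squares, and increasing \(\alpha\) shrinks the coefficients towards zero.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def ridge_coefficients(X, y, alpha):
    """Hypothetical helper: minimizer of ||X w - y||^2 + alpha * ||w||^2, without an intercept."""
    n_features = X.shape[1]
    # Setting the gradient to zero gives (X^T X + alpha * I) w = X^T y
    lhs = X.T @ X + alpha * np.eye(n_features)
    return np.linalg.solve(lhs, X.T @ y)


# Random toy regression data (an assumption, not the California housing data)
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((200, 10))
y_demo = X_demo @ rng.standard_normal(10) + 0.1 * rng.standard_normal(200)

# alpha = 0 reduces to ordinary least squares
w_ols = np.linalg.lstsq(X_demo, y_demo, rcond=None)[0]
w_zero = ridge_coefficients(X_demo, y_demo, 0.0)
print("alpha=0 matches least squares:", np.allclose(w_ols, w_zero))

# Larger alpha pulls the coefficient vector towards zero
for alpha in (1e-3, 1.0, 1e3):
    w = ridge_coefficients(X_demo, y_demo, alpha)
    print(f"alpha={alpha:g}: ||w|| = {np.linalg.norm(w):.4f}")
</code></pre></div></div>

<p>The closed form is only practical for a modest number of features; it is here to show what the penalty does, while the pipeline above simply relies on scikit-learn’s solver.</p>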

<p>Now, to employ HyperOpt, we define a function that computes the quality of a set of hyperparameters by fitting a model and evaluating it on a validation set. While defining it, we annotate each hyperparameter with how it should be searched: degrees are searched uniformly, whereas regularization coefficients are searched in log-space. Finally, we invoke HyperOpt’s <code class="language-plaintext highlighter-rouge">fmin</code> function, which searches the space for the best hyperparameters. Here is the code:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">hyperopt</span> <span class="kn">import</span> <span class="n">hp</span><span class="p">,</span> <span class="n">fmin</span><span class="p">,</span> <span class="n">tpe</span>

<span class="k">def</span> <span class="nf">score</span><span class="p">(</span>
        <span class="n">degree</span><span class="p">:</span> <span class="n">hp</span><span class="p">.</span><span class="n">uniformint</span><span class="p">(</span><span class="s">'degree'</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">500</span><span class="p">),</span>
        <span class="n">alpha</span><span class="p">:</span> <span class="n">hp</span><span class="p">.</span><span class="n">loguniform</span><span class="p">(</span><span class="s">'alpha'</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mf">1e3</span><span class="p">))</span>
<span class="p">):</span>
    <span class="n">pipeline</span> <span class="o">=</span> <span class="n">make_ridge_pipeline</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">alpha</span><span class="p">)</span>
    <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train_samples</span><span class="p">,</span> <span class="n">y_train_samples</span><span class="p">)</span>
    <span class="n">y_pred</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_valid</span><span class="p">)</span>
    <span class="n">error</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_valid</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">error</span>

<span class="n">best_params</span> <span class="o">=</span> <span class="n">fmin</span><span class="p">(</span>
    <span class="n">score</span><span class="p">,</span> <span class="n">space</span><span class="o">=</span><span class="s">'annotated'</span><span class="p">,</span> <span class="n">algo</span><span class="o">=</span><span class="n">tpe</span><span class="p">.</span><span class="n">suggest</span><span class="p">,</span> <span class="n">max_evals</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span>
    <span class="n">rstate</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">default_rng</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="n">best_params</span><span class="p">)</span>
</code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">fmin</code> function shows a progress bar of the 500 trials, which took approximately four minutes, and then we print the best parameters:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>100%|██████████| 500/500 [04:04&lt;00:00,  2.05trial/s, best loss: 57888.41179388089]
{'alpha': np.float64(8.319252707439558), 'degree': np.float64(102.0)}
</code></pre></div></div>
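<p>Why search \(\alpha\) in log-space? Here is a small NumPy sketch of the idea, assuming, as the Hyperopt documentation describes for <code class="language-plaintext highlighter-rouge">hp.loguniform</code>, that a log-uniform draw is \(\exp(u)\) with \(u\) uniform between the two given bounds. It does not call Hyperopt itself. Drawn this way, values of \(\alpha\) between \(10^{-3}\) and \(10^{3}\) land in every decade about equally often, whereas uniform sampling over the same range puts roughly 90% of the trials above 100.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

rng = np.random.default_rng(42)
low, high = 1e-3, 1e3
n_samples = 100_000

# Log-uniform: draw uniformly in log-space, then exponentiate
log_uniform = np.exp(rng.uniform(np.log(low), np.log(high), n_samples))
# Plain uniform draw on the same interval, for comparison
plain_uniform = rng.uniform(low, high, n_samples)

# Fraction of samples in each decade [1e-3, 1e-2), ..., [1e2, 1e3]
edges = 10.0 ** np.arange(-3, 4)
log_fracs = np.histogram(log_uniform, bins=edges)[0] / n_samples
plain_fracs = np.histogram(plain_uniform, bins=edges)[0] / n_samples

print("log-uniform:", np.round(log_fracs, 3))    # roughly 1/6 per decade
print("uniform:    ", np.round(plain_fracs, 3))  # about 0.9 in the last decade
</code></pre></div></div>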

<p>We see that our hyperparameter search chose polynomials of degree 102, so Ridge regression isn’t afraid of high-degree polynomials either 😀. What about the test error? Let’s fit a model with the best hyperparameters and compute it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">best_degree</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">best_params</span><span class="p">[</span><span class="s">'degree'</span><span class="p">])</span>
<span class="n">best_alpha</span> <span class="o">=</span> <span class="n">best_params</span><span class="p">[</span><span class="s">'alpha'</span><span class="p">]</span>
<span class="n">best_pipeline</span> <span class="o">=</span> <span class="n">make_ridge_pipeline</span><span class="p">(</span><span class="n">best_degree</span><span class="p">,</span> <span class="n">best_alpha</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train_samples</span><span class="p">,</span> <span class="n">y_train_samples</span><span class="p">)</span>

<span class="n">test_error</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">best_pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Test error = </span><span class="si">{</span><span class="n">test_error</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Test error = 59598.0548
</code></pre></div></div>

<p>Very close to what we achieved with pruning (59497.32, which is even slightly better). So our pruned model is not bad at all, and can serve as a reasonable baseline. But wait: we have a pretty high-degree polynomial here as well. It’s not of degree 40,000, “only” 102, but that is still high. Maybe we can further improve the model by applying the same pruning trick to the regularized model?</p>

<p>Just for the sake of it, let’s see what the spectra of the polynomials learned by the regularized model look like.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plot_spectra</span><span class="p">(</span><span class="n">best_pipeline</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/california_housing_regularized_spectra.png" alt="california_housing_regularized_spectra" /></p>

<p>At first glance the behavior seems similar: the coefficients decay towards zero, some more rapidly, some more slowly. Now let’s prune the model and see what happens to our test error, applying the same per-feature pruning logic we had before:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">degrees_to_try</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">best_degree</span><span class="p">)</span>
<span class="n">pruned_pipeline</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">best_pipeline</span><span class="p">)</span>

<span class="k">for</span> <span class="n">feature</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
    <span class="n">best_deg</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">best_error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">inf</span>
    <span class="k">for</span> <span class="n">deg</span> <span class="ow">in</span> <span class="n">degrees_to_try</span><span class="p">:</span>
        <span class="n">candidate</span> <span class="o">=</span> <span class="n">prune_feature</span><span class="p">(</span><span class="n">pruned_pipeline</span><span class="p">,</span> <span class="n">feature</span><span class="p">,</span> <span class="n">deg</span><span class="p">)</span>
        <span class="n">pred</span> <span class="o">=</span> <span class="n">candidate</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_valid</span><span class="p">)</span>
        <span class="n">error</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_valid</span><span class="p">,</span> <span class="n">pred</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">error</span> <span class="o">&lt;=</span> <span class="n">best_error</span><span class="p">:</span>
            <span class="n">best_error</span> <span class="o">=</span> <span class="n">error</span>
            <span class="n">best_deg</span> <span class="o">=</span> <span class="n">deg</span>

    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Best degree for feature </span><span class="si">{</span><span class="n">X_train</span><span class="p">.</span><span class="n">columns</span><span class="p">[</span><span class="n">feature</span><span class="p">]</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">best_deg</span><span class="si">}</span><span class="s">, validation error: </span><span class="si">{</span><span class="n">best_error</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="n">pruned_pipeline</span> <span class="o">=</span> <span class="n">prune_feature</span><span class="p">(</span><span class="n">pruned_pipeline</span><span class="p">,</span> <span class="n">feature</span><span class="p">,</span> <span class="n">best_deg</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Best degree for feature longitude: 100, validation error: 57886.75083218771
Best degree for feature latitude: 99, validation error: 57884.23600560439
Best degree for feature housing_median_age: 77, validation error: 57832.46845479303
Best degree for feature total_rooms: 1, validation error: 57258.77663424223
Best degree for feature total_bedrooms: 13, validation error: 56978.72710952217
Best degree for feature population: 27, validation error: 56673.91509748052
Best degree for feature households: 27, validation error: 56453.035725897564
Best degree for feature median_income: 8, validation error: 56142.80647476099
</code></pre></div></div>

<p>Note something interesting: the total rooms feature got a <em>linear</em> function, its degree is one, but the total bedrooms feature got a polynomial of degree 13. We saw before that these two features are highly correlated, so intuitively it wouldn’t make sense to remember a lot of parameters for both of them. It’s just a conjecture, since I don’t have a formal proof, but I believe that the correlation between these two features is what made the regularized Ridge model “choose” to put the overall slope into the total rooms polynomial, and the finer details into the total bedrooms polynomial.</p>

<p>What about the test error?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test_error</span> <span class="o">=</span> <span class="n">root_mean_squared_error</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">pruned_pipeline</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Test error = </span><span class="si">{</span><span class="n">test_error</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Test error = 58598.2801
</code></pre></div></div>

<p>Nice! We reduced it from 59598.05 to 58598.28, by about 1000 dollars, simply by using the fact that the Legendre basis acts like a frequency spectrum whose higher frequencies can be pruned.</p>
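<p>To see why pruning the high-degree end of a Legendre expansion can help, here is a self-contained one-dimensional sketch using NumPy’s <code class="language-plaintext highlighter-rouge">numpy.polynomial.legendre</code> module. It is not the pipeline from this post: we fit a degree-40 Legendre series to noisy samples of an arbitrarily chosen smooth function, then zero out every coefficient above degree 15. Since the true coefficients of a smooth function decay quickly, the discarded tail should be mostly noise, so the truncated series is expected to be closer to the clean function.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre as L

rng = np.random.default_rng(0)


def smooth_fn(x):
    # An arbitrary smooth function on [-1, 1], chosen for illustration
    return np.exp(x) * np.sin(3 * x)


# Noisy training samples
x_train = rng.uniform(-1, 1, 2000)
y_train = smooth_fn(x_train) + 0.1 * rng.standard_normal(x_train.size)

# Least-squares fit of a high-degree Legendre series
coefs = L.legfit(x_train, y_train, deg=40)

# "Prune the spectrum": keep only the coefficients of degrees 0..15
pruned = coefs.copy()
pruned[16:] = 0.0

# Compare both series to the clean function on a fine grid
x_grid = np.linspace(-1, 1, 1001)
rmse_full = np.sqrt(np.mean((L.legval(x_grid, coefs) - smooth_fn(x_grid)) ** 2))
rmse_pruned = np.sqrt(np.mean((L.legval(x_grid, pruned) - smooth_fn(x_grid)) ** 2))

print(f"degree 40 RMSE: {rmse_full:.4f}, pruned to degree 15 RMSE: {rmse_pruned:.4f}")
print("largest |coefficient| above degree 15:", np.abs(coefs[16:]).max())
</code></pre></div></div>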

<h1 id="summary">Summary</h1>

<p>The idea of truncating the spectrum of a function is, of course, not new. It probably dates back to Fourier and his celebrated Fourier series. Fourier series are a great fit for <em>periodic</em> functions, such as a model’s dependence on the time of day. But they may not be a very good fit for a generic feature that exhibits no periodic nature.</p>

<p>Legendre polynomials are one example of a non-periodic “spectrum” composed of orthogonal functions. We pointed out in the previous post that <a href="https://en.wikipedia.org/wiki/Chebyshev_polynomials">Chebyshev polynomials</a> are another famous example. The idea of using these bases to represent solutions of differential equations is abundant in science and engineering, and this post partially drew inspiration from Prof. Nick Trefethen’s phenomenal course “Approximation Theory and Approximation Practice”, and his famous book of the same name<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>. Another great reference is John P. Boyd’s book<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup> on spectral methods for differential equations. If you haven’t been exposed to these subjects yet, I promise a fun and enlightening experience. There are plenty of things the numerical analysis community has studied thoroughly that you might find interesting.</p>
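<p>The difference between the two kinds of spectra is easy to see numerically. The sketch below is our own illustration using only NumPy. It takes \(e^x\) on \([-1, 1]\), which is smooth but not periodic. Its discrete Fourier coefficients decay slowly, because the periodic extension jumps from \(e\) back to \(e^{-1}\) at the boundary. Its Legendre coefficients, computed by Gauss-Legendre quadrature from \(a_n = \frac{2n+1}{2}\int_{-1}^{1} f(x) P_n(x)\,dx\), drop to machine precision after a handful of terms.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import legendre as L


def f(x):
    # Smooth on [-1, 1], but f(-1) != f(1), so not periodic
    return np.exp(x)


# Fourier: sample one period on a uniform grid over [-1, 1) and take the FFT
n_points = 512
x_uniform = -1 + 2 * np.arange(n_points) / n_points
fourier_mag = np.abs(np.fft.rfft(f(x_uniform))) / n_points

# Legendre: a_n = (2n + 1) / 2 * integral of f * P_n, via 64-point Gauss-Legendre quadrature
max_deg = 30
nodes, weights = L.leggauss(64)
P = L.legvander(nodes, max_deg)  # P[i, n] = P_n(nodes[i])
degrees = np.arange(max_deg + 1)
legendre_coefs = (2 * degrees + 1) / 2 * (P.T @ (weights * f(nodes)))

for k in (5, 10, 20):
    print(f"k={k:2d}: |Fourier| = {fourier_mag[k]:.2e}, |Legendre| = {abs(legendre_coefs[k]):.2e}")
</code></pre></div></div>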

<p>The idea of representing a signal by a truncated series of orthogonal functions is abundant in signal processing, and entire research streams on signal and image denoising were built on top of it. What we did here was simply bring some inspiration from other scientific fields into machine learning. In some sense, we “denoised” the high-degree polynomials by truncating them to lower degrees. And we saw it work in two different contexts: without explicit regularization, relying on double descent with extremely high degrees, and with explicit regularization, where the degrees are lower. In both cases, the idea is the same.</p>

<p>The approach here is by no means the best way to build a simple model with polynomial features, and a more thorough attempt to fit a regularized model may well yield a better test error. However, the objective of this post is different: gaining a new insight. Namely, that over-parametrization is what lets the model learn, automatically, to separate signal from noise, and that the Legendre polynomial basis lets us elicit this separation <em>explicitly</em>, by observing that the fitted model has a decaying coefficient spectrum that can be “denoised”. This wonderful property lets us extract the simple model hiding inside the over-parametrized model. And it is exactly this insight that will take us to the next posts in this series!</p>

<hr />

<h1 id="references">References</h1>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>Trefethen, L. N. (2019). <em>Approximation theory and approximation practice, extended edition</em>. Society for Industrial and Applied Mathematics. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>Boyd, J. P. (2001). <em>Chebyshev and Fourier spectral methods</em>. Courier Corporation. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="machine learning" /><category term="feature engineering" /><category term="polynomials" /><category term="polynomial regression" /><category term="Legendre polynomials" /><category term="double descent" /><category term="model pruning" /><category term="feature selection" /><category term="scikit-learn" /><category term="California housing dataset" /><summary type="html"><![CDATA[Legendre polynomial feature regression on California Housing shows double descent; a simple tail-pruning of high-degree coefficients yields smaller, competitive models. Implemented in scikit-learn.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/california_housing_pruned_polys.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/california_housing_pruned_polys.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Let the polynomial monster free</title><link href="https://alexshtf.github.io/2025/03/27/Free-Poly.html" rel="alternate" type="text/html" title="Let the polynomial monster free" /><published>2025-03-27T00:00:00+00:00</published><updated>2025-03-27T00:00:00+00:00</updated><id>https://alexshtf.github.io/2025/03/27/Free-Poly</id><content type="html" xml:base="https://alexshtf.github.io/2025/03/27/Free-Poly.html"><![CDATA[<h1 id="intro">Intro</h1>

<p>In a recent post by Ben Recht, titled <a href="https://www.argmin.net/p/thou-shalt-not-overfit">Thou Shalt Not Overfit</a>, Ben claims that overfitting, in the way it is colloquially described in data science and machine learning, doesn’t exist. Indeed, there is the famous <a href="https://en.wikipedia.org/wiki/Double_descent">double descent</a> phenomenon: trained neural networks that have <em>many</em> more parameters than needed to memorize the training set tend to generalize quite well<sup id="fnref:4" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">1</a></sup>. This includes most of our modern LLMs. So in some cases, models that achieve low training error generalize badly not because they have too many parameters, but because they do <em>not have enough</em>!</p>

<p>In his post, Ben Recht claims that what we call “overfitting” is often just a post-hoc rationalization of a model being <em>wrong</em>, one that makes us ignore the actual underlying issue causing it to be wrong. For example, our model may simply be missing some important features. Ben’s post caused some backlash on <a href="https://x.com/beenwrekt/status/1884988534307873251">X</a>, with several people disagreeing, for example <a href="https://x.com/KabirCreates/status/1884992728880283975">here</a> and <a href="https://x.com/RVrijj/status/1885056275714629914">here</a>, where the authors returned to the textbook examples of high-degree polynomials that overfit.</p>

<p>In previous posts in this series we saw that high-degree polynomials may be very useful in machine learning. We explored fitting polynomials in the Bernstein basis, which we introduced in the post <a href="/2024/01/25/Bernstein-Basis.html">“Keeping the polynomial monster under control”</a>, and its ability to control the shape of the polynomial. This time we would like to explore the exact opposite direction: we relinquish control and set the polynomial monster free. Do high-degree polynomials exhibit double descent, just like neural networks in general and LLMs in particular? Do they generalize well when the degree is much higher than what is needed to memorize the training data? Are the ML textbooks that show us overfitting with polynomial regression <em>wrong</em>?</p>

<p>Well, apparently they are<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">2</a></sup><sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">3</a></sup>, and this is something we will explore in more depth, with examples and code, in this post. I first saw it in an X <a href="https://x.com/adad8m/status/1582231644223987712">post</a> by @adad8m, whose profile also has plenty of interesting content for mathematically inclined readers. Surprisingly, even in this simple case of polynomial features in a linear model, the high-degree polynomials we see in textbooks “overfit” simply because they are fit incorrectly, using the standard basis. We will explore other polynomial bases, available as simple NumPy functions, with which polynomials memorize the training set yet generalize well in the absence of any regularization. It turns out there is more to Ben Recht’s post than meets the eye: overfitting the way it is taught, as a tradeoff between model complexity and generalization, is nonexistent not only for various neural network families, but also for simple polynomial fitting models! The code for this post is available in a <a href="https://github.com/alexshtf/alexshtf.github.io/blob/master/assets/free_polynomial_monster.ipynb">notebook</a>, so you can try it out yourself. Since it’s been a while since I posted about polynomial features, I will attempt to make this post a bit more self-contained than previous posts in this series.</p>

<h1 id="function-fitting">Function fitting</h1>

<p>Let’s start with a small function-fitting exercise. Here is a simple function that should be quite challenging for a polynomial to fit:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="k">def</span> <span class="nf">func</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="n">z</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">x</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">cbrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">z</span><span class="p">))</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">sign</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
</code></pre></div></div>

<p>It is challenging because of its shape:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="n">plot_xs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_xs</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">plot_xs</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging.png" alt="polyfit_challenging" /></p>

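<p>Indeed, its slopes around \(x \approx -0.21\) and \(x \approx 0.85\) are practically vertical: these are the points where the expression inside the cube root crosses zero, and the derivative of \(\sqrt[3]{z}\) is unbounded at \(z = 0\). So any polynomial will have a hard time fitting it. We can now use it to generate noisy training data:</p>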

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">noisy_func</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">noise</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">func</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">noise</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>

<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">n</span> <span class="o">=</span> <span class="mi">50</span>
<span class="n">x</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">noisy_func</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div></div>

<p>This is what it looks like:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_xs</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">plot_xs</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">'function'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="s">'o'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'data'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_data.png" alt="polyfit_challenging_data" /></p>
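<p>To double-check where the slopes of <code class="language-plaintext highlighter-rouge">func</code> become vertical, we can locate the zeros of the expression inside the cube root, since \(\frac{d}{dz}\sqrt[3]{z} = \frac{1}{3}|z|^{-2/3}\) blows up at \(z = 0\). The sketch below is our own illustration: it scans a dense grid for sign changes, refines each one by linear interpolation, and then estimates the slope of \(\sqrt[3]{z}\) there with a central finite difference. The helper name <code class="language-plaintext highlighter-rouge">inner</code> is hypothetical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def inner(x):
    # z(x): the expression inside the cube root of func
    return np.cos(np.pi * x - 1) + 0.1 * np.cos(2 * np.pi * x + 1)


# Dense grid on [-1, 1] and the values of z on it
xs = np.linspace(-1, 1, 20001)
zs = inner(xs)

# Indices i where z changes sign between xs[i] and xs[i + 1]
idx = np.nonzero(np.diff(np.sign(zs)))[0]

# Refine each crossing by linear interpolation between the two grid points
roots = xs[idx] - zs[idx] * (xs[idx + 1] - xs[idx]) / (zs[idx + 1] - zs[idx])
print(np.round(roots, 3))  # approximately -0.212 and 0.850

# func equals np.cbrt(z), so its slope at a root is huge
h = 1e-6
for r in roots:
    slope = (np.cbrt(inner(r + h)) - np.cbrt(inner(r - h))) / (2 * h)
    print(f"x = {r:.3f}, finite-difference slope of func: {slope:.1f}")
</code></pre></div></div>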

<h1 id="fitting-polynomials">Fitting polynomials</h1>
<p>To fit the polynomial to our training data, we use the standard least-squares solver from NumPy:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">feature_matrix_fn</span><span class="p">):</span>
    <span class="c1"># generate polynomial features
</span>    <span class="n">X</span> <span class="o">=</span> <span class="n">feature_matrix_fn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">degree</span><span class="p">)</span>
    
    <span class="c1"># compute coefficients using the L2 loss
</span>    <span class="n">poly</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">lstsq</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">rcond</span><span class="o">=-</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    
    <span class="c1"># compute training error (RMSE)
</span>    <span class="n">train_rmse</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">X</span> <span class="o">@</span> <span class="n">poly</span> <span class="o">-</span> <span class="n">y</span><span class="p">)))</span>
    
    <span class="c1"># return coefficients and training error
</span>    <span class="k">return</span> <span class="n">poly</span><span class="p">,</span> <span class="n">train_rmse</span>
</code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">fit</code> function is a bit generic. It accepts a <code class="language-plaintext highlighter-rouge">feature_matrix_fn</code> that transforms each training sample \(x\) into a row vector of basis polynomial values, such as \((1, x, x^2, \dots, x^n)\) for polynomials of degree \(n\) in the standard basis, and stacks these rows into a matrix.</p>
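<p>To make the contract of <code class="language-plaintext highlighter-rouge">feature_matrix_fn</code> concrete, here is a small sketch of what NumPy's <code class="language-plaintext highlighter-rouge">polyvander</code> and <code class="language-plaintext highlighter-rouge">legvander</code> return. The two sample points are arbitrary and chosen only for illustration. Either function can serve as <code class="language-plaintext highlighter-rouge">feature_matrix_fn</code>, since both take the samples and the degree and return one row per sample.</p>

```python
import numpy as np

# Two arbitrary sample points, chosen only for illustration.
x_demo = np.array([0.5, -1.0])

# Standard basis: row i is (1, x_i, x_i^2, x_i^3), so the shape is (2, 4).
X_demo = np.polynomial.polynomial.polyvander(x_demo, 3)

# Legendre basis: row i is (P_0(x_i), P_1(x_i), P_2(x_i), P_3(x_i)), same shape.
X_leg_demo = np.polynomial.legendre.legvander(x_demo, 3)

print(X_demo)
print(X_leg_demo)
```

<p>For example, the first Legendre row is \((1, 0.5, -0.125, -0.4375)\), which matches \(P_2(0.5) = 1.5 \cdot 0.25 - 0.5\) and \(P_3(0.5) = 2.5 \cdot 0.125 - 1.5 \cdot 0.5\).</p>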

<p>Beyond fitting, we will also need to measure the test error of the fitted polynomials and plot them. To that end, we measure the root mean-squared error between the fitted polynomial and the true function at 10,000 equally spaced points in \([-1, 1]\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">test_fit</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">feature_matrix_fn</span><span class="p">,</span> <span class="n">coefs</span><span class="p">):</span>
    <span class="n">xtest</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">10000</span><span class="p">)</span>
    <span class="n">ytest</span> <span class="o">=</span> <span class="n">feature_matrix_fn</span><span class="p">(</span><span class="n">xtest</span><span class="p">,</span> <span class="n">degree</span><span class="p">)</span> <span class="o">@</span> <span class="n">coefs</span>
    <span class="n">test_rmse</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">ytest</span> <span class="o">-</span> <span class="n">func</span><span class="p">(</span><span class="n">xtest</span><span class="p">))))</span>
    <span class="k">return</span> <span class="n">xtest</span><span class="p">,</span> <span class="n">ytest</span><span class="p">,</span> <span class="n">test_rmse</span>
</code></pre></div></div>

<p>Now it’s time to discuss why we need this genericity in the form of <code class="language-plaintext highlighter-rouge">feature_matrix_fn</code>. It’s a good time to remind ourselves of a thing or two about polynomial fitting that we learned in this series on polynomial features.</p>

<p>Polynomials do not necessarily have to be represented using the standard basis \(\{1, x, x^2, \dots, x^n\}\). For example, consider the polynomial</p>

\[p(x) = 1 + 2x + 3x^2 - 5x^3 \tag{S}\]

<p>But the <em>same</em> polynomial can also be written as</p>

\[p(x) = 2 - x + 2(1.5x^2 - 0.5) - 2(2.5x^3 - 1.5x). \tag{L}\]

<p>In equation (S) above, the polynomial is written in terms of the standard basis and has the coefficients \((1, 2, 3, -5)\), whereas in equation (L) it is written in terms of the basis \(\{1, x, 1.5x^2-0.5, 2.5x^3-1.5x\}\) and has the coefficients \((2, -1, 2, -2)\). This is, by the way, the well-known <a href="https://en.wikipedia.org/wiki/Legendre_polynomials">Legendre polynomial basis</a>, and we shall use it extensively in this post. In general, a polynomial can be written as an inner product of two vectors,</p>

\[p(x) = \langle \mathbf{P}(x), \mathbf{w} \rangle = \sum_{i=0}^n P_i(x) w_i,\]

<p>where \(\mathbf{P}(x)\) is the vector of basis polynomials, and \(\mathbf{w}\) is the vector of coefficients. Obviously, \(p(x)\) is a linear function of its coefficients. So when learning a polynomial, the basis polynomials are the <em>features</em>, and the coefficients are the <em>learned parameters</em> of a linear model.</p>
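<p>A quick way to double-check equations (S) and (L), and the inner-product view, is NumPy's basis-conversion helper <code class="language-plaintext highlighter-rouge">np.polynomial.legendre.poly2leg</code>. This is a sketch for illustration only. It evaluates both representations at a few arbitrary points in \([-1, 1]\) as \(\langle \mathbf{P}(x), \mathbf{w} \rangle\).</p>

```python
import numpy as np
from numpy.polynomial import legendre as L
from numpy.polynomial import polynomial as P

# Coefficients of p(x) = 1 + 2x + 3x^2 - 5x^3 in the standard basis, equation (S).
w_std = np.array([1.0, 2.0, 3.0, -5.0])

# Convert to Legendre coefficients; this should reproduce equation (L).
w_leg = L.poly2leg(w_std)
print(w_leg)  # expect approximately (2, -1, 2, -2)

# Evaluate p as an inner product <P(x), w> in both bases at a few points.
xs = np.linspace(-1, 1, 7)
p_std = P.polyvander(xs, 3) @ w_std
p_leg = L.legvander(xs, 3) @ w_leg
print(np.allclose(p_std, p_leg))
```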

<p>The reason for the genericity in <code class="language-plaintext highlighter-rouge">fit</code> is, of course, our desire to fit polynomials with different bases \(\mathbf{P}\) to demonstrate a point. For least-squares fitting, each row of the data matrix contains the values of the basis functions at one point of the training set. Fortunately, NumPy comes with functions that generate such matrices for a variety of polynomial bases.</p>

<h1 id="visualizing-the-fit-polynomials">Visualizing the fit polynomials</h1>
<p>Now let’s try to visually inspect polynomials of various degrees that fit our function. Below is a function that fits a polynomial using a given basis, computes the train and test errors, and plots the fitting results and the polynomial coefficients:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">fit_and_plot</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">feature_matrix_fn</span><span class="p">):</span>
    <span class="n">poly</span><span class="p">,</span> <span class="n">train_rmse</span> <span class="o">=</span> <span class="n">fit</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">feature_matrix_fn</span><span class="p">)</span>
    <span class="n">xtest</span><span class="p">,</span> <span class="n">ytest</span><span class="p">,</span> <span class="n">test_rmse</span> <span class="o">=</span> <span class="n">test_fit</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">feature_matrix_fn</span><span class="p">,</span> <span class="n">poly</span><span class="p">)</span>
    <span class="n">coef_sum</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">poly</span><span class="p">)</span>
    <span class="n">fig</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
    <span class="n">fig</span><span class="p">.</span><span class="n">suptitle</span><span class="p">(</span><span class="sa">f</span><span class="s">'Degree: </span><span class="si">{</span><span class="n">degree</span><span class="si">}</span><span class="s">, Train RMSE = </span><span class="si">{</span><span class="n">train_rmse</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">g</span><span class="si">}</span><span class="s">, Test RMSE = </span><span class="si">{</span><span class="n">test_rmse</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">g</span><span class="si">}</span><span class="s">, Coef Sum = </span><span class="si">{</span><span class="n">coef_sum</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">g</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>

    <span class="n">axs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">scatter</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
    <span class="n">axs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">plot</span><span class="p">(</span><span class="n">xtest</span><span class="p">,</span> <span class="n">ytest</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'royalblue'</span><span class="p">)</span>
    <span class="n">axs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">plot</span><span class="p">(</span><span class="n">xtest</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">xtest</span><span class="p">),</span> <span class="n">color</span><span class="o">=</span><span class="s">'red'</span><span class="p">)</span>
    <span class="n">axs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_yscale</span><span class="p">(</span><span class="s">'asinh'</span><span class="p">)</span>
    <span class="n">axs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">'Model'</span><span class="p">)</span>

    <span class="n">markerline</span><span class="p">,</span> <span class="n">stemlines</span><span class="p">,</span> <span class="n">baseline</span> <span class="o">=</span> <span class="n">axs</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">stem</span><span class="p">(</span><span class="n">poly</span><span class="p">)</span>
    <span class="n">stemlines</span><span class="p">.</span><span class="n">set_linewidth</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="n">markerline</span><span class="p">.</span><span class="n">set_markersize</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
    <span class="n">axs</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">'Coefficients'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p>Like <code class="language-plaintext highlighter-rouge">fit</code>, it takes the basis as a parameter, so let’s start by plotting with the standard basis. Its feature matrix is constructed by the <code class="language-plaintext highlighter-rouge">np.polynomial.polynomial.polyvander</code> function. The “vander” in the name comes from the <a href="https://en.wikipedia.org/wiki/Vandermonde_matrix">Vandermonde matrix</a>, which is what the standard-basis feature matrix is called. Let’s fit a polynomial of degree 1:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit_and_plot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polyvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_deg1.png" alt="polyfit_challenging_deg1" /></p>

<p>As expected, we got a line. The test error is 0.6457. Now let’s try a polynomial of degree 5:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit_and_plot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polyvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_std_deg5.png" alt="polyfit_challenging_std_deg5" /></p>

<p>It fits the function better, and the test error is smaller: 0.2118. Now let’s try a polynomial of degree 49. This degree is exactly the “interpolation threshold”: a polynomial of degree 49 has 50 coefficients, one per training point, which is just enough to memorize the training set:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit_and_plot</span><span class="p">(</span><span class="mi">49</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polyvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_std_deg49.png" alt="polyfit_challenging_std_deg49" /></p>

<p>Well, it appears that we’re observing what ML 101 textbooks tell us: we’re overfitting! The polynomial is far away from the function. The train error is almost zero, since we fit the training data exactly up to floating point errors, but the test error is \(\sim 6.7 \times 10^{7}\)! Beyond the interpolation threshold, our polynomial is <em>over-parameterized</em>, meaning it has more parameters than needed to memorize the training set. So what about a polynomial of degree 10,000? Let’s try!</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit_and_plot</span><span class="p">(</span><span class="mi">10000</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polyvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_std_deg10000.png" alt="polyfit_challenging_std_deg10000" /></p>

<p>Even worse! But we also observe something suspicious - the coefficients of the high degree polynomial are also large - so is it “overfitting”, or is it something else? Well, let’s try the same exercise with a different basis - the Legendre basis. Its feature matrix is implemented in the <code class="language-plaintext highlighter-rouge">np.polynomial.legendre.legvander</code> function. Here is a polynomial of degree 5.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit_and_plot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_leg_deg5.png" alt="polyfit_challenging_leg_deg5" /></p>

<p>It looks identical to the standard-basis fit. On reflection, this is not a surprise: there is a unique least-squares polynomial of degree 5, and it doesn’t matter whether we represent it in the standard basis or in the Legendre basis. Let’s try degree 49, the interpolation threshold:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit_and_plot</span><span class="p">(</span><span class="mi">49</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_leg_deg49.png" alt="polyfit_challenging_leg_deg49" /></p>

<p>This also looks almost identical to the standard basis: the train error is almost zero, and the test error is awful. Again, this is not a surprise: there is also a unique polynomial of degree 49, the one that passes exactly through the 50 points, so in this case the basis doesn’t matter either, up to floating point errors. But what happens if we crank up the degree to 10,000? Well, let’s try:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fit_and_plot</span><span class="p">(</span><span class="mi">10000</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_leg_deg10000.png" alt="polyfit_challenging_leg_deg10000" /></p>

<p>Whoa! That’s interesting! This extremely over-parameterized polynomial both exactly fits the training data points and is pretty close to the true function. The test error is also not bad - \(0.2156\). What happened to our overfitting from ML 101 textbooks? There is no regularization. No control of the degree. But “magically” our high degree polynomial is not that bad! Also look at the coefficients - they are pretty small. We’ll take a deeper look at the coefficients later, but you may notice they appear to follow some trend of decay: coefficients of higher degrees become smaller.</p>
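<p>One fact that may help in thinking about the small coefficients: when the feature matrix has more columns than rows (and the training points are distinct), there are infinitely many exact interpolants, and <code class="language-plaintext highlighter-rouge">np.linalg.lstsq</code> returns the one with the smallest Euclidean norm. The sketch below uses hypothetical stand-in data (\(\sin(3x)\) plus noise, not the post’s <code class="language-plaintext highlighter-rouge">func</code>) and <code class="language-plaintext highlighter-rouge">rcond=None</code>. It builds a second interpolant by adding a null-space direction and compares the norms.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in data: 50 noisy samples of sin(3x) on [-1, 1].
x_train = rng.uniform(-1, 1, size=50)
y_train = np.sin(3 * x_train) + 0.1 * rng.standard_normal(50)

# 50 samples, 201 Legendre features: many coefficient vectors fit exactly.
X = np.polynomial.legendre.legvander(x_train, 200)
w = np.linalg.lstsq(X, y_train, rcond=None)[0]

# Another exact interpolant: add the component of a random vector that X maps to zero.
r = rng.standard_normal(X.shape[1])
z = r - np.linalg.pinv(X) @ (X @ r)
w_other = w + z

print(np.max(np.abs(X @ w - y_train)))        # round-off level: w interpolates
print(np.max(np.abs(X @ w_other - y_train)))  # round-off level: so does w_other
print(np.linalg.norm(w), np.linalg.norm(w_other))  # w has the smaller norm
```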

<p>So now it’s time to discuss the two bases. We learned in this series that each basis is coupled with a corresponding “operating region”, where it possesses a set of mathematical properties that make it “work well”. I am leaving the meaning of “work well” vague on purpose, but in general it means that the basis has the right properties to accurately fit functions in this operating region from a finite number of training samples. For example, in earlier posts we saw that the Bernstein basis has the interval \([0, 1]\) as its operating region, and that its key property is that its coefficients allow control over the shape of the function. The Legendre basis we just used has the interval \([-1, 1]\) as its operating region. It turns out to be useful in exactly the opposite scenario: when we <em>do not</em> wish to impose any direct control over the shape. We will try to explain why using intuitive tools later in this post. For the standard basis, the operating region is <em>the complex unit circle</em>. This is not a surprise: on the unit circle \(z = e^{i\theta}\), the powers \(z^k = e^{ik\theta}\) are exactly the Fourier basis, the foundation of the entire field of Fourier analysis!</p>
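<p>The unit-circle claim can be illustrated with a short sketch. At \(m\) equally spaced points \(z_j = e^{2\pi i j/m}\), the standard-basis feature matrix has mutually orthogonal columns, so its condition number is 1. The values of <code class="language-plaintext highlighter-rouge">m</code> and <code class="language-plaintext highlighter-rouge">degree</code> are arbitrary; the check relies on the degree being smaller than <code class="language-plaintext highlighter-rouge">m</code>.</p>

```python
import numpy as np

m, degree = 64, 20

# m equally spaced points on the complex unit circle: z_j = exp(2*pi*i*j/m).
z = np.exp(2j * np.pi * np.arange(m) / m)

# Standard-basis feature matrix with columns 1, z, z^2, ..., z^degree.
Z = np.polynomial.polynomial.polyvander(z, degree)

# On the unit circle z^k = exp(i*k*theta), so the columns are discrete Fourier
# modes: Z^H Z = m * I whenever degree is smaller than m.
gram = Z.conj().T @ Z
print(np.allclose(gram, m * np.eye(degree + 1)))
print(np.linalg.cond(Z))
```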

<p>It turns out that ML textbooks that use high degree polynomials to demonstrate the balance between “model complexity” and “generalization error” are <em>wrong</em>. The reason is simply the usage of the standard basis: its operating region is the complex unit circle, so it does not possess the mathematical properties required to fit functions of real numbers. The textbooks simply use the wrong tool for polynomial regression, or, as Ben Recht pointed out in his post, the wrong features in their linear model!</p>

<p>In this small demo I chose the training and test samples to come from \([-1, 1]\) - this is exactly the operating region of the Legendre basis. In practice, when using high degree polynomial features, you have to first normalize your raw numerical features to the operating region of your basis of choice. You can use standard tools, such as min-max scaling with clipping, or a normalization function such as \(x \to \tanh(\alpha x+ \beta)\). You should <em>never</em> use polynomial features outside the operating region of the basis of your choice!</p>
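<p>Here is a minimal sketch of both normalization options. The helpers <code class="language-plaintext highlighter-rouge">minmax_to_unit_interval</code> and <code class="language-plaintext highlighter-rouge">tanh_to_unit_interval</code> are hypothetical, and the raw values, <code class="language-plaintext highlighter-rouge">alpha</code>, and <code class="language-plaintext highlighter-rouge">beta</code> are arbitrary. In practice, the scaling parameters would be chosen on the training set and reused unchanged for test data.</p>

```python
import numpy as np

def minmax_to_unit_interval(values, lo, hi):
    """Hypothetical helper: map [lo, hi] linearly onto [-1, 1], clipping outliers."""
    scaled = 2 * (np.asarray(values, dtype=float) - lo) / (hi - lo) - 1
    return np.clip(scaled, -1.0, 1.0)

def tanh_to_unit_interval(values, alpha, beta):
    """Hypothetical helper: smooth squashing x -> tanh(alpha * x + beta)."""
    return np.tanh(alpha * np.asarray(values, dtype=float) + beta)

# Scaling parameters come from the (made-up) training values.
raw_train = np.array([3.0, 7.5, 12.0, 20.0])
lo, hi = raw_train.min(), raw_train.max()

# Test values, two of them outside the training range.
raw_test = np.array([1.0, 10.0, 25.0])

scaled_test = minmax_to_unit_interval(raw_test, lo, hi)
squashed_test = tanh_to_unit_interval(raw_test, alpha=0.1, beta=-1.0)
print(scaled_test, squashed_test)

# The normalized feature now lies in [-1, 1], the Legendre operating region.
features = np.polynomial.legendre.legvander(scaled_test, 8)
print(features.shape)
```

<p>Clipping keeps out-of-range test values at the boundary, while the \(\tanh\) variant never quite reaches \(\pm 1\) and needs no explicit clipping.</p>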

<h1 id="error-as-a-function-of-the-degree">Error as a function of the degree</h1>

<p>So far we saw examples of polynomials of degree 1, 5, 49, and 10,000. But what about the other degrees? Well, let’s plot the test error as a function of the degree and see! Here is a function that plots the train and test errors for a range of polynomial degrees, and also shows the “interpolation threshold”, where the number of parameters equals the number of training points:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">plot_errors</span><span class="p">(</span><span class="n">feature_matrix_fn</span><span class="p">):</span>
    <span class="c1"># define a set of degrees that look nice in a plot - linearly spaced
</span>    <span class="c1"># low degrees, and geometrically spaced high degrees.
</span>    <span class="n">degs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">r_</span><span class="p">[</span>
        <span class="n">np</span><span class="p">.</span><span class="n">sort</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">unique</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">15</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">))),</span>
        <span class="n">np</span><span class="p">.</span><span class="n">sort</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">unique</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">geomspace</span><span class="p">(</span><span class="n">n</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="mi">100000</span><span class="p">,</span> <span class="mi">20</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)))</span>
    <span class="p">]</span>

    <span class="c1"># compute train and test errors
</span>    <span class="n">train_errors</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">degs</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span>
    <span class="n">test_errors</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">degs</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">deg</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">tqdm</span><span class="p">(</span><span class="n">degs</span><span class="p">)):</span>
        <span class="n">poly</span><span class="p">,</span> <span class="n">train_errors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">fit</span><span class="p">(</span><span class="n">deg</span><span class="p">,</span> <span class="n">feature_matrix_fn</span><span class="p">)</span>
        <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">test_errors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">test_fit</span><span class="p">(</span><span class="n">deg</span><span class="p">,</span> <span class="n">feature_matrix_fn</span><span class="p">,</span> <span class="n">poly</span><span class="p">)</span>

    <span class="c1"># plot the train and test errors, and the interpolation threshold vertical
</span>    <span class="c1"># bar.
</span>    <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">degs</span><span class="p">,</span> <span class="n">train_errors</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'Train'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">degs</span><span class="p">,</span> <span class="n">test_errors</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'Test'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">axvline</span><span class="p">(</span><span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'royalblue'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'dotted'</span><span class="p">,</span>
                <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">axhline</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">test_errors</span><span class="p">),</span> <span class="n">color</span><span class="o">=</span><span class="s">'olive'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'dashed'</span><span class="p">,</span>
                <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">'Min RMSE = </span><span class="si">{</span><span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">test_errors</span><span class="p">)</span><span class="si">:</span><span class="p">.</span><span class="mi">3</span><span class="n">g</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">([</span><span class="o">-</span><span class="mf">1e-3</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">test_errors</span><span class="p">)])</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">'asinh'</span><span class="p">,</span> <span class="n">linear_width</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">xscale</span><span class="p">(</span><span class="s">'log'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'Degree'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'Root Mean Squared Error'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p>Now let’s plot the errors for the Legendre basis:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plot_errors</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_errors_leg.png" alt="polyfit_errors_leg" /></p>

<p>We nicely see the double-descent phenomenon! At the interpolation threshold, the train error drops towards zero, whereas the test error skyrockets. But as the degree increases further, the test error goes down again! What about the standard basis?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plot_errors</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">polyvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_errors_std.png" alt="polyfit_errors_std" /></p>

<p>Below the interpolation threshold, the behavior is identical to that of the Legendre basis. This is because for each degree below 50 there is only <em>one</em> least-squares fitting polynomial. But as we cross the interpolation threshold, the train error goes down towards zero, whereas the test error skyrockets. This is not because of some magical phenomenon called “overfitting”; it is because the standard basis is simply the wrong tool for polynomial fitting.</p>
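<p>The uniqueness argument can be sanity-checked numerically. The sketch below uses stand-in data, since <code class="language-plaintext highlighter-rouge">func</code> and <code class="language-plaintext highlighter-rouge">noisy_func</code> are defined earlier in the post. The function \(\sin(3x)\), the noise level, and the helper <code class="language-plaintext highlighter-rouge">fit_predict</code> are hypothetical choices for illustration. It fits a degree-5 polynomial in both bases and compares the predictions.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in data: 50 noisy samples of sin(3x) on [-1, 1].
x_train = rng.uniform(-1, 1, size=50)
y_train = np.sin(3 * x_train) + 0.1 * rng.standard_normal(50)
x_eval = np.linspace(-1, 1, 1000)

def fit_predict(vander_fn, degree):
    """Least-squares fit in the basis given by vander_fn, then predict on x_eval."""
    coefs = np.linalg.lstsq(vander_fn(x_train, degree), y_train, rcond=None)[0]
    return vander_fn(x_eval, degree) @ coefs

pred_std = fit_predict(np.polynomial.polynomial.polyvander, 5)
pred_leg = fit_predict(np.polynomial.legendre.legvander, 5)

# Same polynomial in two representations: the gap should be at round-off level.
print(np.max(np.abs(pred_std - pred_leg)))
```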

<p>If we look again at the Legendre errors plot, we will see that the best low-degree polynomial still achieves a better test error than any high-degree polynomial. So why is it interesting that high degree polynomials do not overfit, if a low degree polynomial is better? If you scroll up, you can see that we sampled 50 training points. But if we re-run the entire simulation with 600 training points, the plot of the Legendre basis errors is quite different:</p>

<p><img src="https://alexshtf.github.io/assets/polyfit_errors_leg_600pts.png" alt="polyfit_errors_leg_600pts" /></p>

<p>This time, the high degree polynomials generalize even better than any low degree polynomial. This is similar to what we observe in large neural networks: in the “big data” regime, when we have huge amounts of data, larger models perform better. This is another nail in the coffin of the popular belief that higher model complexity leads to worse generalization! The belief doesn’t hold even for polynomial fitting, the very setting used to demonstrate “overfitting” in ML textbooks. It’s not about model complexity; it’s about structure!</p>

<h1 id="what-makes-the-standard-basis-bad-and-the-legendre-basis-good">What makes the standard basis “bad”, and the Legendre basis “good”?</h1>

<p>There are mathematically rigorous reasons, pointed out in Schaeffer’s paper on double descent<sup id="fnref:1:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">2</a></sup>, but here we are aiming for a more intuitive explanation. When fitting with the standard basis, our feature matrix for polynomials of degree \(n\) looks like this:</p>

\[\mathbf{X} = \begin{pmatrix}
1 &amp; x_1 &amp; x_1^2 &amp; \dots &amp; x_1^n \\
1 &amp; x_2 &amp; x_2^2 &amp; \dots &amp; x_2^n \\
&amp;&amp; \vdots &amp; &amp; \\
1 &amp; x_m &amp; x_m^2 &amp; \dots &amp; x_m^n \\
\end{pmatrix}\]

<p>Intuitively, any two even powers, such as \(x^4\) and \(x^6\), are very similar: both grow quickly as \(x\) gets farther from the origin. The same similarity holds for any two odd powers. This means that the columns of the matrix \(\mathbf{X}\) above are <em>highly correlated</em>. As the degree grows, we keep adding almost redundant columns that our linear model has to use as features, and intuitively, many such “non-informative” features are what make the linear model behave badly. This is formally analyzed in Schaeffer’s paper through the singular values of the matrix \(\mathbf{X}\). In essence, this is also one of the reasons the coefficients of high degree polynomials found using the standard basis were large: such matrices, called <a href="https://en.wikipedia.org/wiki/Condition_number#Matrices">ill conditioned matrices</a>, challenge the algorithms used to fit least-squares models.</p>
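<p>To make the correlation and conditioning argument concrete, here is a small NumPy sketch comparing the two feature matrices. The grid of 2,000 equispaced points in \([-1, 1]\) and the degree 20 are arbitrary choices for illustration, not the settings of the experiments above. The snippet computes the condition number of each matrix and the correlation between two high even-degree columns.</p>

```python
import numpy as np

# Illustrative setup: 2,000 equispaced points in [-1, 1] and polynomials of degree 20.
x = np.linspace(-1, 1, 2000)
degree = 20

X_std = np.polynomial.polynomial.polyvander(x, degree)  # columns 1, x, ..., x^20
X_leg = np.polynomial.legendre.legvander(x, degree)     # columns P_0(x), ..., P_20(x)

# Condition number: ratio of the largest to the smallest singular value.
cond_std = np.linalg.cond(X_std)
cond_leg = np.linalg.cond(X_leg)

# Correlation between two high even-degree columns in each basis.
corr_std = np.corrcoef(X_std[:, 18], X_std[:, 20])[0, 1]
corr_leg = np.corrcoef(X_leg[:, 18], X_leg[:, 20])[0, 1]

print(f"standard basis: cond = {cond_std:.2e}, corr(x^18, x^20) = {corr_std:.4f}")
print(f"Legendre basis: cond = {cond_leg:.2e}, corr(P_18, P_20) = {corr_leg:.4f}")
```

<p>On such a grid, the standard basis should give a condition number several orders of magnitude larger than the Legendre basis, with nearly perfectly correlated \(x^{18}\) and \(x^{20}\) columns.</p>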

<p>The Legendre basis is different. Let’s plot the first five polynomials, of degrees 0 through 4:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">handles</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_xs</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">(</span><span class="n">plot_xs</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">(</span><span class="n">handles</span><span class="o">=</span><span class="n">handles</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="p">[</span><span class="sa">f</span><span class="s">'Degree </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">'</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">)])</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/legendre_basis.png" alt="legendre_basis" /></p>

<p>We can see that these polynomials <em>oscillate</em> in \([-1, 1]\)! Higher degrees oscillate more than lower degrees. Moreover, each polynomial oscillates “in different places” than the polynomials of lower degrees, meaning it tends to curve up where lower-degree polynomials curve down, and vice versa. This is formally expressed by the <strong>orthogonality</strong> property of Legendre polynomials: for any two Legendre basis polynomials \(P_i(x)\) and \(P_j(x)\) of degrees \(i \neq j\), we have:</p>

\[\langle P_i, P_j \rangle = \int_{-1}^1 P_i(x) P_j(x) dx = 0.\]

<p>Integrals are, of course, just “infinite sums”. Therefore, for enough uniformly sampled points \(x_1, \dots, x_m\), we will have</p>

\[\sum_{k=1}^m P_i(x_k) P_j(x_k) \approx 0\]
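<p>This approximation is easy to check numerically. The sketch below draws 100,000 points uniformly from \([-1, 1]\) and uses degrees up to 5, both illustrative assumptions, to estimate the averages \(\frac{1}{m}\sum_k P_i(x_k) P_j(x_k)\). For uniform samples these approach \(\frac{1}{2}\int_{-1}^1 P_i P_j \, dx\), which is \(\frac{1}{2i+1}\) when \(i = j\) and zero otherwise.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
m, degree = 100_000, 5
x = rng.uniform(-1, 1, size=m)

X = np.polynomial.legendre.legvander(x, degree)   # X[k, i] = P_i(x_k)

# Entry (i, j) approximates E[P_i(x) P_j(x)] = (1/2) * integral of P_i P_j over [-1, 1].
gram = X.T @ X / m
expected = np.diag(1.0 / (2 * np.arange(degree + 1) + 1))

print(np.round(gram, 3))
print("largest deviation from the exact values:", np.abs(gram - expected).max())
```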

<p>Why is it interesting? Well, look at the feature matrix for Legendre polynomials:</p>

\[\mathbf{X} = \begin{pmatrix}
P_0(x_1) &amp; P_1(x_1) &amp; \dots &amp; P_n(x_1) \\
P_0(x_2) &amp; P_1(x_2) &amp; \dots &amp; P_n(x_2) \\
 &amp; &amp; \vdots &amp; \\
P_0(x_m) &amp; P_1(x_m) &amp; \dots &amp; P_n(x_m) \\ 
\end{pmatrix}\]

<p>The orthogonality property means its columns have little chance to be correlated! Intuitively, adding more and more columns, corresponding to higher and higher degrees, introduces information the model did not previously have. Fitting a linear model with “informative features”, of course, has a much higher chance of success. But is this informativeness enough? Well, it is not! There are infinitely many polynomials of degree 10,000 that exactly memorize a training set of 50 points, and some of them are <em>bad</em>: we just found one using the standard basis. And since the Legendre basis is a basis, it can also represent this bad polynomial.</p>

<p>Out of the infinite set of high degree polynomials that exactly memorize the training set, our NumPy least-squares solver chooses only <em>one</em>: the one whose coefficients have the smallest Euclidean norm. In other words, the least-squares solver has a “preference” for low-norm coefficients. It is exactly this <em>interplay</em> between the way we represent our polynomial and the preference of our optimizer that facilitates the good generalization of high-degree Legendre polynomials.</p>
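<p>The minimum-norm preference can be checked directly. In the sketch below, the data (a sine plus Gaussian noise at 50 random points) and the degree 100 are hypothetical stand-ins for the post’s experiment. Because the system has more unknowns than equations, <code class="language-plaintext highlighter-rouge">np.linalg.lstsq</code> returns the minimum-norm solution. Adding any null-space direction of \(\mathbf{X}\), taken here from the SVD, yields another interpolating polynomial with a larger coefficient norm.</p>

```python
import numpy as np

rng = np.random.default_rng(1)
m, degree = 50, 100                                    # 50 points, 101 coefficients
x = rng.uniform(-1, 1, size=m)
y = np.sin(np.pi * x) + 0.1 * rng.standard_normal(m)  # hypothetical stand-in data

X = np.polynomial.legendre.legvander(x, degree)       # shape (50, 101): more columns than rows

# For an underdetermined system, lstsq returns the minimum-Euclidean-norm solution.
coef, *_ = np.linalg.lstsq(X, y, rcond=None)

# Any vector in the null space of X can be added without changing the fit.
_, _, Vt = np.linalg.svd(X)   # full_matrices=True by default, so Vt is (101, 101)
null_vec = Vt[-1]             # X @ null_vec is numerically zero, since rank(X) <= 50
other = coef + null_vec

print("max train residuals:", np.abs(X @ coef - y).max(), np.abs(X @ other - y).max())
print("coefficient norms:", np.linalg.norm(coef), np.linalg.norm(other))
```

<p>Both coefficient vectors memorize the training set, but the one returned by <code class="language-plaintext highlighter-rouge">lstsq</code> is orthogonal to the null space, so any other interpolating solution has a strictly larger norm.</p>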

<p>Let’s try to understand this interplay a bit more. As pointed out above, higher degree Legendre polynomials oscillate more. So we can think of them as a kind of “frequency domain”: coefficients of higher degrees capture the tendency for more rapid oscillations. The preference for low-norm solutions makes the coefficients as small as possible while still memorizing the training set. That’s why we saw this “decay” of the Legendre polynomial coefficients: the least-squares solver found a polynomial that oscillates as little as possible while still memorizing the training set. We need the low-degree coefficients to capture the overall shape of the function, but the high degree coefficients, which correspond to rapid oscillations, can be kept small, since they only need to be large enough to capture the small deviations from this overall shape.</p>
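<p>The coefficient decay can be observed in a small sketch. It hypothetically replaces the post’s noisy function with 50 noisy samples of \(\sin(\pi x)\), fits a minimum-norm interpolating polynomial of degree 1,000 in the Legendre basis, and prints the average coefficient magnitude in three blocks of degrees. The low-degree block, which carries the overall shape, should dominate the high-degree blocks.</p>

```python
import numpy as np

rng = np.random.default_rng(2)
m, degree = 50, 1000
x = rng.uniform(-1, 1, size=m)
y = np.sin(np.pi * x) + 0.1 * rng.standard_normal(m)  # hypothetical stand-in for the noisy function

X = np.polynomial.legendre.legvander(x, degree)       # shape (50, 1001)
coef, *_ = np.linalg.lstsq(X, y, rcond=None)           # minimum-norm interpolating coefficients

# Average coefficient magnitude in consecutive blocks of degrees.
blocks = [(0, 10), (10, 100), (100, 1001)]
for lo, hi in blocks:
    print(f"degrees {lo}-{hi - 1}: mean |c| = {np.abs(coef[lo:hi]).mean():.4f}")
```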

<p>As evidence, let’s take the polynomial of degree 10,000 we just fit using the Legendre basis, and truncate its degree by using only the first \(k\) basis functions:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">truncated_fit_plots</span><span class="p">(</span><span class="n">degree</span><span class="o">=</span><span class="mi">10000</span><span class="p">,</span> 
                       <span class="n">truncates</span><span class="o">=</span><span class="p">[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">40</span><span class="p">],</span>
                       <span class="n">feature_matrix_fn</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">):</span>
    <span class="c1"># fit a full degree polynomial, and produce data for plotting
</span>    <span class="n">poly</span><span class="p">,</span> <span class="n">train_rmse</span> <span class="o">=</span> <span class="n">fit</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">feature_matrix_fn</span><span class="p">)</span>

    <span class="c1"># data to plot of the full degree polynomial
</span>    <span class="n">xtest</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">10000</span><span class="p">)</span> 
    <span class="n">ytest</span> <span class="o">=</span> <span class="n">feature_matrix_fn</span><span class="p">(</span><span class="n">xtest</span><span class="p">,</span> <span class="n">degree</span><span class="p">)</span> <span class="o">@</span> <span class="n">poly</span>

    <span class="c1"># create subplots
</span>    <span class="n">n_rows</span> <span class="o">=</span> <span class="mi">2</span>
    <span class="n">n_cols</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="n">ceil</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">truncates</span><span class="p">)</span> <span class="o">/</span> <span class="n">n_rows</span><span class="p">))</span>
    <span class="n">fig</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span>
        <span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">n_cols</span> <span class="o">*</span> <span class="n">fig_width</span><span class="p">,</span> <span class="n">n_rows</span> <span class="o">*</span> <span class="n">fig_height</span><span class="p">),</span> 
        <span class="n">layout</span><span class="o">=</span><span class="s">'constrained'</span><span class="p">)</span>
    
    <span class="c1"># plot each truncate together with the full degree polynomial
</span>    <span class="k">for</span> <span class="n">trunc_deg</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">truncates</span><span class="p">,</span> <span class="n">axs</span><span class="p">.</span><span class="n">flatten</span><span class="p">()):</span>
        <span class="n">trunc_poly</span> <span class="o">=</span> <span class="n">poly</span><span class="p">[:(</span><span class="n">trunc_deg</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
        <span class="n">ytest_truncated</span> <span class="o">=</span> <span class="n">feature_matrix_fn</span><span class="p">(</span><span class="n">xtest</span><span class="p">,</span> <span class="n">trunc_deg</span><span class="p">)</span> <span class="o">@</span> <span class="n">trunc_poly</span>

        <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xtest</span><span class="p">,</span> <span class="n">ytest</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'royalblue'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'full'</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xtest</span><span class="p">,</span> <span class="n">ytest_truncated</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'red'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">'truncated'</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">'Truncate deg=</span><span class="si">{</span><span class="n">trunc_deg</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
        <span class="n">handles</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">ax</span><span class="p">.</span><span class="n">get_legend_handles_labels</span><span class="p">()</span>
        
    <span class="n">fig</span><span class="p">.</span><span class="n">legend</span><span class="p">(</span><span class="n">handles</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">loc</span><span class="o">=</span><span class="s">'outside right upper'</span><span class="p">)</span>
    <span class="n">fig</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
   
<span class="n">truncated_fit_plots</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_trunc.png" alt="polyfit_challenging_trunc" /></p>

<p>This is a kind of <em>pruning</em> - we are effectively zeroing-out the coefficients of the higher degrees, and using only the lower degree coefficients. We can see that the lowest degrees indeed capture the overall shape, and higher degrees begin to capture the fine deviations from this overall shape towards the noisy dataset. The model actually <em>needs</em> a lot of parameters to be able to learn to differentiate signal from noise!</p>

<p>Deep learning is no different. All optimizers used in practice for deep learning, such as SGD, Adam, or AdamW have some “preference”, just like NumPy’s least-squares solver has a preference for small norm solutions. And just like with our polynomials, it is exactly the interplay between the structure of many deep neural network families and optimizer preference that facilitates this double-descent with neural networks and allows us to scale them to huge sizes without losing generalization power. To the best of my knowledge a good theory has yet to be discovered, but just looking up “overparametrized networks” or “double descent” in your favorite search engine for academic papers will yield <em>tons</em> of literature! Personally, I believe that deeper understanding of this phenomenon with simpler models, such as polynomial fitting, may yield a better theory for deep learning.</p>

<h1 id="but-extrapolation-polynomials-dont-extrapolate-well">But extrapolation! Polynomials don’t extrapolate well!</h1>
<p>A common claim is that polynomials “go crazy” if you use them outside of the domain your training data comes from. Well, let’s try fitting our function using training data in \([-0.5, 0.5]\), and plotting it together with the fit polynomial in \([-1, 1]\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">n</span> <span class="o">=</span> <span class="mi">25</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="o">-</span> <span class="mf">0.5</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">noisy_func</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

<span class="n">fit_and_plot</span><span class="p">(</span><span class="mi">10000</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">legvander</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/polyfit_challenging_leg_deg10000_extrapolate.png" alt="polyfit_challenging_leg_deg10000_extrapolate" /></p>

<p>We do not have training data beyond \([-0.5, 0.5]\), so we have no way of knowing the true behavior of the function outside of this interval. The best we can expect is some “graceful” behavior. And indeed, our super high degree polynomial behaves quite gracefully: it decays towards zero as we get farther away from \([-0.5, 0.5]\). Not bad for extrapolation!</p>

<p>Of course, if we go outside the operating region of the Legendre basis, the interval \([-1, 1]\), the polynomial will not be graceful at all: it will quickly explode towards infinity. But that’s the whole point. As long as you always normalize your features to the operating region of your polynomial basis, you should expect graceful extrapolation behavior. There is actually no problem “extrapolating” outside of the domain most of your training data came from, as long as you stay inside the operating region of your basis.</p>
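<p>Here is a minimal sketch of both points. It assumes the feature’s range \([\text{lo}, \text{hi}]\) is known in advance, maps it affinely onto \([-1, 1]\), and compares a single Legendre polynomial inside and outside that interval. The helper name <code class="language-plaintext highlighter-rouge">to_operating_region</code> and the age range are hypothetical. Inside \([-1, 1]\), every Legendre polynomial is bounded by 1 in absolute value, whereas just outside it, \(P_{50}\) is astronomically large.</p>

```python
import numpy as np

def to_operating_region(x, lo, hi):
    """Hypothetical helper: affinely map values from a known range [lo, hi] onto [-1, 1]."""
    return 2.0 * (np.asarray(x, dtype=float) - lo) / (hi - lo) - 1.0

# Coefficient vector selecting the single basis polynomial P_50.
p50 = np.zeros(51)
p50[-1] = 1.0

inside = np.linspace(-1, 1, 10_001)
max_inside = np.abs(np.polynomial.legendre.legval(inside, p50)).max()
outside = np.polynomial.legendre.legval(1.5, p50)
print(f"max |P_50(x)| on [-1, 1]: {max_inside:.3f}")
print(f"P_50(1.5): {outside:.3e}")

# Example: a feature with a hypothetical known range [0, 120], such as age in years.
ages = np.array([0.0, 30.0, 60.0, 120.0])
print(to_operating_region(ages, 0.0, 120.0))
```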

<h1 id="the-ml-community-hasnt-caught-up">The ML community hasn’t caught up</h1>
<p>It turns out that what we saw here was discovered a long time ago by the differential equations community. Many natural phenomena are modeled by differential equations, meaning equations whose unknown is a function. Oftentimes, the solution cannot be expressed analytically and is approximated.</p>

<p>One popular approximation method is using polynomials, which allow expressing the problem at hand using a set of linear equations in the polynomial coefficients: some of the equations stem from the laws of physics, whereas the others come from (possibly noisy) measurements. The function is then approximated by finding the coefficients stemming from exact fit to the laws of physics, and least-squares fit to the noisy data. Essentially, this is a kind of machine learning.</p>
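<p>The procedure above can be sketched as an equality-constrained least-squares problem. The toy equation \(u''(x) = -\pi^2 \sin(\pi x)\), whose solutions are \(\sin(\pi x) + a + bx\), the collocation points, and the measurement noise are all illustrative assumptions, not a method taken from Boyd’s book. The “physics” rows require the second derivative of the Legendre expansion to match the right-hand side exactly at Chebyshev points. The noisy measurements pin down \(a\) and \(b\) in the least-squares sense, and the two are combined through the standard KKT (Lagrange multiplier) system.</p>

```python
import numpy as np
from numpy.polynomial import legendre as leg

rng = np.random.default_rng(3)
degree = 20
n_coef = degree + 1

# Illustrative problem: u''(x) = -pi^2 sin(pi x) on [-1, 1], whose solutions are
# sin(pi x) + a + b x. The noisy measurements below should pin down a and b.
def rhs_fn(x):
    return -np.pi ** 2 * np.sin(np.pi * x)

def u_true(x):
    return np.sin(np.pi * x) + 0.2 + 0.5 * x

# "Physics" rows: second derivative of each basis polynomial P_k at collocation points.
n_col = degree - 1                                         # 19 points determine a degree-18 u''
x_col = np.cos(np.pi * (np.arange(n_col) + 0.5) / n_col)   # Chebyshev points of the first kind
eye = np.eye(n_coef)
C = np.column_stack([leg.legval(x_col, leg.legder(eye[k], m=2)) for k in range(n_coef)])
d = rhs_fn(x_col)

# "Measurement" rows: noisy samples of the true solution at random points.
x_obs = rng.uniform(-1, 1, size=30)
X = leg.legvander(x_obs, degree)
y = u_true(x_obs) + 0.01 * rng.standard_normal(x_obs.size)

# Minimize ||X c - y||^2 subject to C c = d through the KKT system
#   [X^T X  C^T] [c     ]   [X^T y]
#   [C      0  ] [lambda] = [d    ]
kkt = np.block([[X.T @ X, C.T],
                [C, np.zeros((n_col, n_col))]])
sol = np.linalg.solve(kkt, np.concatenate([X.T @ y, d]))
coef = sol[:n_coef]

grid = np.linspace(-1, 1, 1001)
max_err = np.abs(leg.legval(grid, coef) - u_true(grid)).max()
print("physics residual:", np.abs(C @ coef - d).max())
print("max error vs. true solution:", max_err)
```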

<p>It turns out that the differential equations community, and the numerical analysis community in general, has already done extensive research on approximating functions using polynomials. It has been known for a very long time that families of orthogonal polynomials “work well”, and this is described extensively in John Boyd’s book<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">4</a></sup>. Unfortunately, knowledge doesn’t always flow between scientific disciplines, and this is one of those cases. Otherwise, ML textbooks and courses wouldn’t use polynomial regression to demonstrate “overfitting”.</p>

<h1 id="recap">Recap</h1>
<p>Polynomials are typically used to demonstrate the need to balance model complexity and generalization, but this demonstration rests on a myth. The “overfitting” phenomena in this case are just the bad numerical behavior of the standard basis, together with a misunderstanding of the “operating region” of polynomial bases.</p>

<p>Moreover, ML theory typically deals a lot with model classes, also called hypothesis classes. It attempts to answer questions such as: what is the generalization power of linear regression functions? What about linear classifiers? But it’s not only about model classes. The standard basis of degree 10,000 and the Legendre basis of degree 10,000 represent the same class of models: the class of polynomials of degree 10,000. Yet we see radically different results with two representations of the same class of models!</p>

<p>The representation, of course, is part of the learning algorithm itself. We can learn our 10,000 degree polynomial in many ways. This is also not a surprise - the same can be said about linear classifiers. A linear classifier can be trained as logistic regression or as a support vector machine - the  learning algorithms may have different generalization power, just like two different polynomial bases. This is despite the fact that both yield exactly the same family of models - the family of linear classifiers.</p>

<p>In the next post we shall try to create a scikit-learn component that generates Legendre basis for numerical features, and use it on some real-world datasets. Let’s see what happens when we crank-up the degree on something more serious than just fitting a polynomial curve to noisy samples from a function!</p>

<p>In later posts we shall see how we can use the preference of least-squares solvers towards small norm solutions to facilitate some control, and study another orthogonal polynomial basis - the <a href="https://en.wikipedia.org/wiki/Chebyshev_polynomials">Chebyshev polynomial basis</a>. You can already take the notebook from this post and try it out yourself - its feature matrix can be constructed using the  <code class="language-plaintext highlighter-rouge">numpy.polynomial.chebyshev.chebvander</code> function. You will see that it also exhibits double descent. But it’s a bit different, and we shall take a deeper look at the difference between these two bases in the context of machine learning in later posts.</p>

<h1 id="references">References</h1>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:4" role="doc-endnote">
      <p>Belkin, Mikhail, et al. “Reconciling modern machine-learning practice and the classical bias–variance trade-off.” <em>Proceedings of the National Academy of Sciences</em> 116.32 (2019): 15849-15854. <a href="#fnref:4" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:1" role="doc-endnote">
      <p>Schaeffer, Rylan, Mikail Khona, Zachary Robertson, Akhilan Boopathy, Kateryna Pistunova, Jason W. Rocks, Ila Rani Fiete, and Oluwasanmi Koyejo. “Double descent demystified: Identifying, interpreting &amp; ablating the sources of a deep learning puzzle.” <em>arXiv preprint arXiv:2303.14151</em> (2023). <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:1:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>Philipp Benner. “Double descent”. <a href="https://github.com/pbenner/double-descent">https://github.com/pbenner/double-descent</a> (2022) <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:3" role="doc-endnote">
      <p>John P. Boyd. “Chebyshev and Fourier Spectral Methods”, 2nd edition. Dover Publications (2001). <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="machine learning" /><category term="feature engineering" /><category term="polynomials" /><category term="polynomial regression" /><category term="double descent" /><category term="overparameterization" /><category term="generalization" /><category term="Legendre polynomials" /><category term="Chebyshev polynomials" /><category term="Fourier features" /><summary type="html"><![CDATA[Overparameterized polynomial regression can exhibit double descent: past the interpolation threshold, some bases memorize and still generalize. Experiments compare power, Legendre, Chebyshev, and Fourier features.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/polyfit_challenging_trunc.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/polyfit_challenging_trunc.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Shape restricted function models via polyhedral cones</title><link href="https://alexshtf.github.io/2024/11/09/Shape-Restricted-Models-Polyhedral.html" rel="alternate" type="text/html" title="Shape restricted function models via polyhedral cones" /><published>2024-11-09T00:00:00+00:00</published><updated>2024-11-09T00:00:00+00:00</updated><id>https://alexshtf.github.io/2024/11/09/Shape-Restricted-Models-Polyhedral</id><content type="html" xml:base="https://alexshtf.github.io/2024/11/09/Shape-Restricted-Models-Polyhedral.html"><![CDATA[<h1 id="intro">Intro</h1>

<p>In the <a href="/2024/10/14/Shape-Restricted-Models.html">previous post</a> we learned a technique to use a neural network to produce the <em>coefficients</em> of a function in a basis of our choice. This “embedding vector” of coefficients can be regarded as a sequence of real numbers, whose shape affects the shape of the function the model represents.  We wrote simple layers that can produce an increasing or a decreasing sequence of numbers - that was easy. But what about a convex sequence? Or a sequence that is both increasing and concave?</p>

<p>In this post we will dig deeper into how we can constrain the output of our embedding vectors to satisfy a family of constraints known as <em>polyhedral constraints</em>, that include the above use-cases. Then, after implementing a small PyTorch framework for the basic ideas presented here, we shall connect it to the previous post by fitting functions having interesting shape constraints.</p>

<p>The reason it’s a post and not a paper is, again, because the idea is <em>not</em> novel, and this post has been largely inspired by an idea introduced in an ECCV paper <sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup> to solve a computer vision problem. Here we won’t be solving a computer vision problem, but fitting shape-constrained functions. The focus here is on applying the same idea for a different purpose, and demonstrating it on a concrete, tangible, and accessible use-case, with runnable code in a <a href="https://github.com/alexshtf/alexshtf.github.io/blob/master/assets/polyhedral_constraints.ipynb">notebook</a> you can deploy to Colab. So let’s get started!</p>

<h1 id="polyhedral-sets">Polyhedral sets</h1>

<p>Consider the non-decreasing sequence</p>

\[u_1 \leq u_2 \leq \dots \leq u_n.\]

<p>The above is a compact way to write a system of linear inequalities:</p>

\[\begin{align*}
u_2 - u_1 &amp;\geq 0 \\
u_3 - u_2 &amp;\geq 0 \\
&amp;\vdots \\
u_n - u_{n-1} &amp;\geq 0
\end{align*}\]

<p>The system can also be written in <em>matrix</em> form:</p>

\[\underbrace{\begin{pmatrix}
-1 &amp; 1 &amp; 0 &amp; 0 &amp; \dots &amp; 0 \\
0 &amp; -1 &amp; 1 &amp; 0 &amp; \dots &amp; 0 \\
\vdots &amp;  &amp; \ddots &amp; \ddots &amp; &amp; \vdots \\
0 &amp; 0 &amp; 0 &amp; \dots &amp; -1 &amp; 1
\end{pmatrix}}_{\mathbf{A}}
\underbrace{
\begin{pmatrix}
u_1 \\ u_2 \\ \vdots \\ u_n
\end{pmatrix}}_{\mathbf{u}}
\geq 
\begin{pmatrix}
0 \\ \vdots \\ 0 
\end{pmatrix}\]
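<p>In code, a matrix of this form is just the first difference of the identity matrix: each row has a \(-1\) and a \(1\) on consecutive entries. The small NumPy sketch below uses \(n = 6\) and two hand-picked sequences as illustrative assumptions. It builds the matrix and checks the inequalities, including the fact that scaling a feasible sequence by a non-negative number keeps it feasible.</p>

```python
import numpy as np

n = 6
# Row i of A is e_{i+1} - e_i, so (A @ u)[i] = u_{i+1} - u_i.
A = np.diff(np.eye(n), axis=0)   # shape (n - 1, n)
print(A)

def in_cone(A, u):
    """True when A @ u >= 0 holds componentwise."""
    return bool(np.all(A @ u >= 0))

u_nondecreasing = np.array([0.0, 0.5, 0.5, 1.0, 2.0, 4.0])
u_other = np.array([0.0, 1.0, 0.5, 1.0, 2.0, 4.0])   # decreases from 1.0 to 0.5

print(in_cone(A, u_nondecreasing), in_cone(A, u_other))
print(in_cone(A, 3.0 * u_nondecreasing))   # non-negative scaling stays in the set
```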

<p>It turns out sets of all vectors \(\mathbf{u}\) satisfying constraints of the above form have been extensively studied. They are called <em>polyhedral cones</em>. So let’s get formally introduced: for a matrix \(\mathbf{A}\), consider the set \(C(\mathbf{A})\) defined by:</p>

\[C(\mathbf{A}) = \left\{ \mathbf{u} : \mathbf{A} \mathbf{u} \geq \mathbf{0}  \right\}\]

<p>Let’s unpack the name <em>polyhedral cone</em>. Why <em>polyhedral</em>? A polyhedron, in mathematics, is a high dimensional generalisation of what we know as a polygon: it’s something that has ‘vertices’, or other forms of ‘sharp boundaries’. We will indeed soon see that such sets do have sharp boundaries. Why <em>cones</em>? A cone is also a generalisation of what we know as a 2D or a 3D cone: a set made of infinite rays emanating from the origin, as depicted below:</p>

<p><img src="https://alexshtf.github.io/assets/cone.png" alt="cone" /></p>

<p>Mathematically, a cone is a set such that if \(\mathbf{u}\) is in the set, then also \(\lambda \mathbf{u}\) is in the set for any \(\lambda \geq 0\). This means that any vector \(\mathbf{u}\) in the cone signifies a <em>direction</em> of a ray belonging to the cone. In a sense, a cone is defined by the set of its rays. Our set \(C(\mathbf{A})\) is indeed a cone, since if the inequality \(\mathbf{A} \mathbf{u} \geq 0\) is satisfied by some vector \(\mathbf{u}\), then it’s satisfied by any non-negative multiple of \(\mathbf{u}\).</p>

<p>It turns out that non-decreasing, non-increasing, convex, or concave constraints are all polyhedral conic constraints, since they can be written as linear inequalities with a zero on the right-hand side. We saw non-decreasing constraints above. What about convexity? Well, we just need to require the discrete analogue of the second derivative to be non-negative:</p>

\[u_{i-1} - 2 u_i + u_{i+1} \geq 0 \qquad \forall i = 2, \dots, n-1.\]

<p>Concavity is similar - the discrete analogue of the second derivative is non-positive, or equivalently, its negation is non-negative:</p>

\[-u_{i-1} + 2 u_i - u_{i+1} \geq 0 \qquad \forall i = 2, \dots, n - 1.\]
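<p>Both constraint matrices are second differences of the identity matrix, with opposite signs. The sketch below builds them with <code class="language-plaintext highlighter-rouge">np.diff</code> and checks the inequalities. It uses \(n = 7\) samples of \(x^2\) and \(\sqrt{x + 1}\) on an equispaced grid as illustrative convex and concave sequences.</p>

```python
import numpy as np

n = 7
grid = np.linspace(-1, 1, n)

# Row i of D2 is e_i - 2 e_{i+1} + e_{i+2}, the discrete second derivative.
D2 = np.diff(np.eye(n), n=2, axis=0)   # shape (n - 2, n)
A_convex = D2                           # u_{i-1} - 2 u_i + u_{i+1} >= 0
A_concave = -D2                         # -u_{i-1} + 2 u_i - u_{i+1} >= 0

convex_seq = grid ** 2
concave_seq = np.sqrt(grid + 1)

print(A_convex @ convex_seq)    # all entries non-negative
print(A_concave @ concave_seq)  # all entries non-negative
print(np.all(A_convex @ concave_seq >= 0))  # a concave sequence violates convexity
```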

<p>But what makes polyhedral cones special and useful in machine learning? It’s the following version of the <em>Weyl-Minkowski theorem</em><sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup><sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup> for polyhedral cones:</p>

<blockquote>
  <p>There is \(\mathbf{A} \in \mathbb{R}^{m \times n}\) such that
\(C = \{ \mathbf{u} :  \mathbf{A} \mathbf{u} \geq 0 \}\)
if and only if there are vectors \(\mathbf{r}_1, \dots, \mathbf{r}_p\) such that
\(C = \{ t_1 \mathbf{r}_1 + \dots + t_p \mathbf{r}_p : t_i \geq 0 \}\).</p>

</blockquote>

<p>Namely, a polyhedral cone can be represented either by linear inequalities with a zero right-hand side, or by a linear combination with non-negative coefficients \(\mathbf{u} =  t_1 \mathbf{r}_1 + \dots + t_p \mathbf{r}_p\). The vectors \(\mathbf{r}_1, \dots, \mathbf{r}_p\) are called the <em>generators</em> or <em>extremal rays</em> of the set. Generators, because they are used to generate any point in the set. Extremal, because these are exactly the “sharp boundaries” of \(C(\mathbf{A})\), as depicted below. The black vectors are the extremal rays of the green cone (source: Wikipedia):</p>

<p><img src="https://alexshtf.github.io/assets/polyhedral-cone.png" alt="polyhedral_cone" /></p>

<p>To make things more concise, we can embed the generators into the <em>columns</em> of the matrix \(\mathbf{R}\), and write:</p>

\[\mathbf{u} = \mathbf{R} \mathbf{t}\]

<p>The theorem ensures that for every \(\mathbf{A}\) representing the inequalities, we have a corresponding \(\mathbf{R}\) representing the generator rays, and vice versa.</p>
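<p>For the non-decreasing cone we started with, one possible set of generators can be written down by hand: the all-ones vector and its negation, plus the “step” vectors \((0, \dots, 0, 1, \dots, 1)\). Any non-decreasing \(\mathbf{u}\) equals \(u_1\) times the all-ones vector (using the negated column when \(u_1 < 0\)) plus a non-negative combination of steps, whose coefficients are the increments \(u_{i+1} - u_i\). The NumPy sketch below checks both directions, with \(n = 5\) and random coefficients as illustrative choices.</p>

```python
import numpy as np

n = 5
ones = np.ones((n, 1))
steps = np.tril(np.ones((n, n)))[:, 1:]   # column j has ones from row j + 1 downwards
R = np.hstack([ones, -ones, steps])        # generators as columns, shape (n, n + 1)

A = np.diff(np.eye(n), axis=0)             # non-decreasing constraints: A @ u >= 0

# Direction 1: any non-negative t gives a non-decreasing u = R @ t.
rng = np.random.default_rng(0)
t = rng.uniform(0.0, 1.0, size=R.shape[1])
u = R @ t
print(u, np.all(A @ u >= 0))

# Direction 2: a given non-decreasing sequence has non-negative coefficients.
u_target = np.array([-1.0, -1.0, 0.0, 2.0, 2.5])
t_rec = np.concatenate([[max(u_target[0], 0.0), max(-u_target[0], 0.0)],
                        np.diff(u_target)])
print(t_rec, R @ t_rec)
```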

<p>But why would an ML practitioner be interested in this result? A machine-learned model can easily produce a non-negative vector \(\mathbf{t}\) using known activation functions, such as <code class="language-plaintext highlighter-rouge">ReLU</code> or <code class="language-plaintext highlighter-rouge">SoftPlus</code>. Then, we can use a linear layer with matrix \(\mathbf{R}\) to compute a vector \(\mathbf{R} \mathbf{t}\) in our desired cone. Therefore, a composition of a non-negativity layer and a linear layer with an appropriate matrix can be thought of as a “polyhedral cone” layer: it generates vectors that lie in a polyhedral cone of our choice. This is illustrated below:</p>

<p><img src="https://alexshtf.github.io/assets/polyedral_cone_layer.png" alt="polyedral_cone_layer" /></p>

<p>Note that in this case \(\mathbf{R}\) is not <em>learned</em>, but rather is a constant matrix specifically designed to generate the cone we need. So polyhedral cones are useful in ML because we can make sure that our neural network produces vectors that lie in the cone, <em>by design</em>, using elementary tools that every ML practitioner knows: <code class="language-plaintext highlighter-rouge">ReLU</code> activations and linear layers!</p>
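<p>The forward pass of such a layer takes only a couple of lines. Below is a framework-agnostic NumPy sketch, not the PyTorch implementation developed in this post. The function names and the choice of generators are illustrative assumptions: the lower-triangular all-ones matrix generates non-negative, non-decreasing sequences. Whatever the input, the output satisfies the cone’s inequalities by construction.</p>

```python
import numpy as np

def softplus(z):
    # log(1 + exp(z)) computed stably; strictly positive for every input.
    return np.logaddexp(0.0, z)

def cone_layer(z, R):
    """Map a batch z of shape (batch, p) to points in the cone generated by the p columns of R."""
    t = softplus(z)      # non-negative generator coefficients
    return t @ R.T       # row b equals R @ t[b]

n = 5
R = np.tril(np.ones((n, n)))       # illustrative generators: all-ones and step vectors
A = np.diff(np.eye(n), axis=0)     # non-decreasing constraints

rng = np.random.default_rng(0)
z = 3.0 * rng.standard_normal((4, n))   # stand-in for the previous layer's output
u = cone_layer(z, R)
print(u)
print(np.all(u @ A.T >= 0), np.all(u >= 0))
```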

<p>To end our discussion of the generator ray representation, we note that we can have two generators that are just the negation of each other, i.e. \(\mathbf{r}_i = -\mathbf{r}_j\). For example, consider the matrix:</p>

\[\mathbf{R} = \begin{pmatrix}
1 &amp; -1 &amp; 3 \\
2 &amp; -2 &amp; -2 \\
3 &amp; -3 &amp; 0
\end{pmatrix}\]

<p>Its first two columns are just negations of each other. This matrix represents a cone that is generated by</p>

\[\begin{align*}
\mathbf{R} \mathbf{t} &amp;= t_1 \cdot \begin{pmatrix}1 \\ 2 \\ 3 \end{pmatrix} + t_2 \cdot \begin{pmatrix}-1 \\ -2 \\ -3 \end{pmatrix} + t_3 \cdot \begin{pmatrix} 3 \\ -2 \\ 0 \end{pmatrix} \\
 &amp;= (t_1 - t_2)\cdot  \begin{pmatrix}1 \\ 2 \\ 3\end{pmatrix} + t_3 \cdot \begin{pmatrix}3 \\ -2 \\ 0 \end{pmatrix}
\end{align*}\]

<p>for any \(t_1, t_2, t_3 \geq 0\). Since \(t_1 - t_2\) can be any real number, positive or negative, two columns with opposite signs can simply be ‘shrunk’ into one column whose coefficient does not have to be non-negative. The coefficient \(t_3\), of course, remains non-negative.</p>
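<p>A quick numerical sanity check of this merging argument, sketched with NumPy (any array library would do): we draw random non-negative coefficients and verify that the three-column matrix produces the same vector as the two-column matrix whose first coefficient \(t_1 - t_2\) is free.</p>

```python
import numpy as np

# The example matrix: its first two columns are negations of each other.
R = np.array([[1., -1.,  3.],
              [2., -2., -2.],
              [3., -3.,  0.]])

# The "shrunk" matrix: one column for the +/- pair, plus the remaining ray.
R_shrunk = R[:, [0, 2]]

rng = np.random.default_rng(0)
t = rng.uniform(0.0, 1.0, size=3)      # t_1, t_2, t_3 >= 0

u = R @ t
s = np.array([t[0] - t[1], t[2]])      # first coefficient is free, second stays >= 0
u_shrunk = R_shrunk @ s

print(np.allclose(u, u_shrunk))        # True
```

<p>By linearity the equality holds for any choice of \(t_1, t_2, t_3\), so the check is really just an illustration of the algebra above.</p>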

<p>So in practice, when we are given a matrix \(\mathbf{R}\) of generators for our cone, we should also have corresponding instructions for each component of \(\mathbf{t}\): is it non-negative, or arbitrary? This little detail seems like a complication, but we shall soon see that the opposite is true. The components that do not have to be non-negative will be referred to as the <em>linear</em> components. Those that have to be non-negative are the <em>conic</em> components.</p>

<p>Now let’s implement our ‘polyhedral cone layer’, as depicted in the illustration above. Since not all generator coefficients have to be non-negative, we shall first implement a ‘masked activation’ layer that applies an activation function only to the subset of components selected by a mask. It’s a bit intricate, since we need to apply the activation to a subset of the components of our input vector while keeping the remaining components unchanged. We achieve this with the <code class="language-plaintext highlighter-rouge">Tensor.masked_scatter</code> method, which writes the activated values back into the positions selected by the mask. It’s a bit advanced, but it works. Finally, the class is easier to use if we specify the mask of the components to which we do <em>not</em> apply the activation. So here it is:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">MaskedActivation</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">):</span>
    <span class="s">"""
    Applies activation to (potentially) a subset of the input components. 
    Args:
      mask: A boolean mask of the coordinates to which we should NOT apply the activation.
        `None` means applying the activation to all components. It is assumed
        that `mask` is a 1D tensor that applies to the last dimension of the input.
      activation: The activation to apply.
    """</span>
    <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span><span class="p">()</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s">'mask'</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
  
  <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">mask</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
      <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">activation</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> 
    <span class="k">else</span><span class="p">:</span>
      <span class="n">activated</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">activation</span><span class="p">(</span><span class="n">x</span><span class="p">[...,</span> <span class="o">~</span><span class="bp">self</span><span class="p">.</span><span class="n">mask</span><span class="p">])</span>
      <span class="n">result</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">masked_scatter</span><span class="p">(</span><span class="o">~</span><span class="bp">self</span><span class="p">.</span><span class="n">mask</span><span class="p">,</span> <span class="n">activated</span><span class="p">)</span>
      <span class="k">return</span> <span class="n">result</span>
</code></pre></div></div>
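<p>To see what the scatter step does on a batch, here is a small standalone sketch (the mask and inputs are arbitrary examples). It compares the indexing-plus-<code class="language-plaintext highlighter-rouge">masked_scatter</code> approach used by <code class="language-plaintext highlighter-rouge">MaskedActivation</code> with an alternative formulation based on <code class="language-plaintext highlighter-rouge">torch.where</code>, which computes the activation everywhere and then picks, per component, either the raw or the activated value.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 5)                                          # a batch of 4 inputs, 5 components each
linear_mask = torch.tensor([True, False, False, True, False])  # components left unchanged
relu = nn.ReLU()

# The approach used by MaskedActivation: activate the selected components, then
# scatter them back into a copy of x. Positions are filled in row-major order,
# which matches the order of the flattened `activated` tensor, so batches work.
activated = relu(x[..., ~linear_mask])
via_scatter = x.masked_scatter(~linear_mask, activated)

# An alternative: apply the activation everywhere and select per component.
via_where = torch.where(linear_mask, x, relu(x))

print(torch.equal(via_scatter, via_where))  # True
```

<p>The <code class="language-plaintext highlighter-rouge">torch.where</code> variant is arguably simpler to read, at the cost of evaluating the activation on components whose result is discarded.</p>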

<p>To represent our generators, we shall also need a helper function to create a linear layer with constant (or frozen) weight matrix:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">frozen_linear</span><span class="p">(</span><span class="n">weights</span><span class="p">):</span>
  <span class="n">in_dim</span><span class="p">,</span> <span class="n">out_dim</span> <span class="o">=</span> <span class="n">weights</span><span class="p">.</span><span class="n">shape</span>
  <span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_dim</span><span class="p">,</span> <span class="n">out_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
  <span class="n">layer</span><span class="p">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">weights</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">layer</span>
</code></pre></div></div>
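<p>The shape convention matters here: <code class="language-plaintext highlighter-rouge">nn.Linear</code> stores its weight with shape <code class="language-plaintext highlighter-rouge">(out_features, in_features)</code> and computes <code class="language-plaintext highlighter-rouge">x @ W.T</code>, so a frozen layer holding \(\mathbf{R}\) maps every row \(\mathbf{t}\) of a batch to \(\mathbf{R}\mathbf{t}\). A small standalone check with an arbitrary non-square matrix, where mixing up the two dimensions would show up:</p>

```python
import torch
from torch import nn

R = torch.tensor([[1., 0.,  2.],
                  [0., 1., -1.]])   # 2 x 3: maps R^3 -> R^2
out_dim, in_dim = R.shape           # (out_features, in_features)

layer = nn.Linear(in_dim, out_dim, bias=False)
layer.weight = nn.Parameter(R, requires_grad=False)  # frozen: no gradient updates

t = torch.tensor([[1., 2., 3.],
                  [0., 1., 0.]])    # a batch of two coefficient vectors
u = layer(t)                        # row i equals R @ t[i]

print(u)  # rows: [7, -1] and [0, 1]
```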

<p>And now we can create our polyhedral cone layer: a composition of our masked activation with the frozen generator layer:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">polyhedral_cone_module</span><span class="p">(</span><span class="n">generators</span><span class="p">,</span> <span class="n">linear_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">nonneg_activation</span><span class="o">=</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">):</span>
  <span class="n">generators</span> <span class="o">=</span> <span class="n">generators</span> <span class="o">/</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">vector_norm</span><span class="p">(</span><span class="n">generators</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
  	<span class="n">MaskedActivation</span><span class="p">(</span><span class="n">linear_mask</span><span class="p">,</span> <span class="n">nonneg_activation</span><span class="p">),</span>
    <span class="n">frozen_linear</span><span class="p">(</span><span class="n">generators</span><span class="p">)</span>
  <span class="p">)</span>
</code></pre></div></div>

<p>This is where it becomes clear why our <code class="language-plaintext highlighter-rouge">MaskedActivation</code> layer accepts a mask telling it where to <em>avoid</em> applying the activation: this is exactly the linear mask. You may also have noticed that the first line of the function above does something unusual: it normalizes the columns of the generator matrix. Why? In practice, algorithms for training ML models like ‘nice’ normalized data and ‘nice’ normalized matrices. Since the generators only denote the <em>directions</em> of the cone’s rays, their lengths do not change the cone they generate. I first wrote this entire blog post without the normalization step, ran into numerical issues, and only then added it. To save you that detour, the step is included from the start.</p>
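<p>Why does normalization leave the cone unchanged? Scaling column \(j\) by \(1/\lVert \mathbf{r}_j \rVert\) can be undone by scaling the coefficient \(t_j\) by the same positive factor, which keeps it non-negative. A minimal numerical sketch with random generators:</p>

```python
import torch

torch.manual_seed(0)
R = torch.randn(3, 4)                                     # arbitrary generators in the columns
norms = torch.linalg.vector_norm(R, dim=0, keepdim=True)  # shape (1, 4): one norm per column
R_normalized = R / norms

t = torch.rand(4)                   # non-negative coefficients for the normalized generators
u = R_normalized @ t

# The same vector comes from the original generators with rescaled coefficients
# t / norms, which are still non-negative, so both matrices generate the same cone.
t_original = t / norms.squeeze(0)
print(torch.allclose(u, R @ t_original))  # True
```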

<p>Now let’s try out our generic polyhedral cone module for non-decreasing sequences. For that, we need to represent non-decreasing sequences using a generator matrix \(\mathbf{R}\). Luckily, it’s quite simple to do manually. As we saw in the last post, a non-decreasing sequence can be created by specifying its first component by an arbitrary number \(t_1\), and sequentially adding non-negative numbers \(t_2, t_3, \dots, t_{n} \geq 0\) to produce the next elements:</p>

\[\begin{pmatrix}u_1 \\ u_2 \\ u_3 \\ \vdots \\ u_n\end{pmatrix}
= 
\begin{pmatrix}t_1 \\ t_1 + t_2 \\ t_1+t_2+t_3 \\ \vdots \\ t_1 + \dots + t_n \end{pmatrix}
= 
\underbrace{\begin{pmatrix}
1 &amp; 0 &amp; 0 &amp; \dots &amp; 0 \\
1 &amp; 1 &amp; 0 &amp; \dots &amp; 0 \\
1 &amp; 1 &amp; 1 &amp; \dots &amp; 0 \\
\vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
1 &amp; 1 &amp; 1 &amp; \dots &amp; 1
\end{pmatrix}}_{\mathbf{R}}
\begin{pmatrix}
t_1 \\ t_2 \\ \vdots \\ t_n
\end{pmatrix}\]
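<p>A concrete instance of this construction, with hand-picked coefficients: the lower-triangular matrix of ones turns \(\mathbf{t}\) into partial sums, and the sign requirement on \(t_2, \dots, t_n\) is exactly what makes the result non-decreasing. A single negative coefficient breaks the guarantee:</p>

```python
import torch

n = 5
R = torch.tril(torch.ones(n, n))              # ones on and below the diagonal

t = torch.tensor([-2.0, 0.5, 0.0, 1.0, 3.0])  # t_1 is free, t_2..t_n >= 0
u = R @ t                                     # partial sums: -2.0, -1.5, -1.5, -0.5, 2.5
print(bool((u[1:] >= u[:-1]).all()))          # True: non-decreasing

t_bad = torch.tensor([-2.0, 0.5, -1.0, 1.0, 3.0])  # a negative t_3
u_bad = R @ t_bad                                  # partial sums: -2.0, -1.5, -2.5, -1.5, 1.5
print(bool((u_bad[1:] >= u_bad[:-1]).all()))       # False
```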

<p>Alternatively, we can specify the <em>last</em> component by an arbitrary number \(t_n\), and subtract non-negative numbers \(t_{n-1}, \dots, t_1 \geq 0\) ‘backwards’ to generate the non-decreasing sequence:</p>

\[\begin{pmatrix} u_1 \\ \vdots \\ u_{n-1} \\ u_n \end{pmatrix}
=
\begin{pmatrix}
t_n - t_{n - 1} - \dots - t_2 - t_1 \\
\vdots \\
t_n - t_{n-1} \\
t_n
\end{pmatrix}
= 
\begin{pmatrix}
-1 &amp; -1 &amp; -1 &amp; \dots &amp; -1 &amp; 1 \\
0 &amp; -1 &amp; -1 &amp; \dots &amp; -1 &amp; 1 \\
\vdots &amp; \vdots &amp;  &amp; \ddots &amp; \vdots &amp; \vdots \\
0 &amp;  0 &amp; \dots &amp; 0  &amp; -1 &amp; 1 \\
0 &amp; 0 &amp; \dots &amp; 0 &amp; 0 &amp; 1
\end{pmatrix}
\begin{pmatrix}t_1 \\ t_2 \\ \vdots \\ t_n \end{pmatrix}\]

<p>This demonstrates that the generator representation of a cone is not unique.</p>
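<p>To see the non-uniqueness concretely, here is a small sketch that builds both matrices for \(n = 5\) and recovers the coefficients of an example non-decreasing vector in each representation. In both cases the coefficients are read off from consecutive differences; only the position of the free coefficient changes (first for the lower-triangular matrix, last for the ‘backwards’ one).</p>

```python
import torch

n = 5
R1 = torch.tril(torch.ones(n, n))                   # first representation: t_1 is linear
R2 = torch.cat([-torch.triu(torch.ones(n, n - 1)),  # second representation: t_n is linear
                torch.ones(n, 1)], dim=1)

u = torch.tensor([-1.0, 0.0, 0.0, 2.0, 3.5])        # a non-decreasing vector

t1 = torch.cat([u[:1], torch.diff(u)])   # (u_1, u_2 - u_1, ..., u_n - u_{n-1})
t2 = torch.cat([torch.diff(u), u[-1:]])  # (u_2 - u_1, ..., u_n - u_{n-1}, u_n)

print(torch.allclose(R1 @ t1, u), torch.allclose(R2 @ t2, u))  # True True
print(bool((t1[1:] >= 0).all()), bool((t2[:-1] >= 0).all()))  # True True: conic parts are >= 0
```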

<p>Now let’s construct a PyTorch module that produces non-decreasing vectors using our <code class="language-plaintext highlighter-rouge">polyhedral_cone_module</code> function. Here, we shall use the <em>first</em> matrix \(\mathbf{R}\): the lower-triangular matrix of ones. Since \(t_1\) is a ‘linear’ coefficient, meaning it doesn’t have to be non-negative, we shall also pass the appropriate mask:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">nondecreasing_module</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
  <span class="n">generators</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span> <span class="c1"># the matrix of ones under the diagonal
</span>  <span class="n">linear_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
  <span class="n">linear_mask</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
  <span class="k">return</span> <span class="n">polyhedral_cone_module</span><span class="p">(</span><span class="n">generators</span><span class="p">,</span> <span class="n">linear_mask</span><span class="p">)</span>
</code></pre></div></div>
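<p>As an aside, multiplying by the lower-triangular matrix of ones is just a cumulative sum, so the same cone can be produced without a dense \(n \times n\) matrix. The sketch below is a hypothetical alternative (<code class="language-plaintext highlighter-rouge">nondecreasing_cumsum</code> is not part of the post’s code): it applies ReLU to all components but the first and then calls <code class="language-plaintext highlighter-rouge">torch.cumsum</code>. It skips the column normalization done by <code class="language-plaintext highlighter-rouge">polyhedral_cone_module</code>, so its outputs are not numerically identical to the module’s, although the set of reachable vectors is the same.</p>

```python
import torch
import torch.nn.functional as F

def nondecreasing_cumsum(x):
    """Hypothetical alternative: x[..., 0] is free, the remaining increments pass through ReLU."""
    increments = torch.cat([x[..., :1], F.relu(x[..., 1:])], dim=-1)
    return torch.cumsum(increments, dim=-1)  # equals increments @ tril(ones).T, in O(n)

torch.manual_seed(0)
u = nondecreasing_cumsum(torch.randn(3, 8))     # a batch of 3 vectors
print(bool((u[..., 1:] >= u[..., :-1]).all()))  # True
```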

<p>Let’s try it out, and use our module to create some non-decreasing vectors:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">nondecreasing_module</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>

<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">mod</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="n">mod</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="n">mod</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([0.3367, 1.0963, 1.9135, 2.7285, 3.0102, 3.6145, 5.9269, 6.3511])
tensor([0.4617, 1.2974, 2.2933, 3.4709, 4.8660, 5.0353, 5.3516, 6.6343])
tensor([1.3221, 2.5051, 2.8869, 3.2736, 4.8561, 5.9502, 6.4929, 7.6605])
</code></pre></div></div>

<p>They do appear non-decreasing! So now that we have the machinery working, let’s look for a more generic way to derive the matrix \(\mathbf{R}\) that corresponds to a desired polyhedral cone. But we have two obstacles.</p>

<p>The first obstacle is combinatorial: to represent an \(n\)-dimensional polyhedral cone, we may need more than \(2^n\) generator rays in the worst case. This is a consequence of McMullen’s theorem<sup id="fnref:4" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">4</a></sup>. This means that even for a moderate dimension \(n\) (a few hundred), we might have a matrix \(\mathbf{R}\) with more columns than there are atoms in the entire universe. Fortunately, the polyhedral cones in this post do not suffer from this exponential explosion of generator rays. We just saw an example with the non-decreasing cone: the matrix \(\mathbf{R}\) had exactly \(n\) columns.</p>

<p>The second obstacle is practical: how do we compute the generator rays of the cone we desire? Is there a convenient Python library to which we feed the inequalities we want, and which gives us the corresponding matrix \(\mathbf{R}\)? For non-decreasing sequences we could derive it ourselves, but this may be non-trivial in general. As ML practitioners, we just want to write PyTorch layers; we don’t want to deal with computational polyhedral geometry. It turns out there is such a library, and that’s exactly what we shall explore in the next section.</p>

<h1 id="computing-generator-matrices">Computing generator matrices</h1>

<p>There are two prominent parallel streams of research on algorithms for translating between the generator and the inequality representations of polyhedral cones. One is a family of algorithms based on the so-called Reverse Search Method<sup id="fnref:5" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">5</a></sup>, whereas the other is based on the so-called Double Description method<sup id="fnref:6" role="doc-noteref"><a href="#fn:6" class="footnote" rel="footnote">6</a></sup>. It’s important to acknowledge both, but in this post we shall use the double-description method, because it’s accessible from a simple Python library. So we will not dive into the algorithms, but rather use the library. After all, as ML researchers, we prefer using the results developed by the talented polyhedral geometry researchers, over doing polyhedral geometry research ourselves.</p>

<p>So there is a C library, <code class="language-plaintext highlighter-rouge">cddlib</code>, that was written by Komei Fukua. Its source code can be found on <a href="https://github.com/cddlib/cddlib">GitHub</a>. He is also the author of an interesting open access book on polyhedral computation, which is available <a href="https://doi.org/10.3929/ethz-b-000426218">here</a>.  It turns out his library also has a nice Python wrapper - <code class="language-plaintext highlighter-rouge">pycddlib</code>. So let’s install it in our Colab notebook:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>%pip install pycddlib-standalone
</code></pre></div></div>

<p>Now we can import and use it in our notebook. Let’s describe a polyhedral cone representing non-decreasing sequences, and use this example to explain the input format that the <code class="language-plaintext highlighter-rouge">pycddlib</code> library expects:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">cdd</span>

<span class="k">def</span> <span class="nf">make_cdd_cone</span><span class="p">(</span><span class="n">A</span><span class="p">):</span>
  <span class="s">"""Creates a libcdd polyhedral cone given the matrix describing the inequalities A x ≥ 0. """</span>
  <span class="c1"># define the RHS of the inequalities. In our case - everything is ≥ 0
</span>  <span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>

  <span class="c1"># the library operates assuming we describe inequalities in the form:
</span>  <span class="c1">#    b + A x ≥ 0
</span>  <span class="c1"># and expects b and A to be concatenated into one big matrix.
</span>  <span class="n">Ab</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">hstack</span><span class="p">([</span><span class="n">b</span><span class="p">,</span> <span class="n">A</span><span class="p">])</span>

  <span class="c1"># create and print the polyhedral cone object
</span>  <span class="n">mat</span> <span class="o">=</span> <span class="n">cdd</span><span class="p">.</span><span class="n">matrix_from_array</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">rep_type</span><span class="o">=</span><span class="n">cdd</span><span class="p">.</span><span class="n">RepType</span><span class="p">.</span><span class="n">INEQUALITY</span><span class="p">)</span>
  <span class="n">poly</span> <span class="o">=</span> <span class="n">cdd</span><span class="p">.</span><span class="n">polyhedron_from_matrix</span><span class="p">(</span><span class="n">mat</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">poly</span>

<span class="c1"># polyhedral cone for non-decreasing sequences:
</span><span class="n">cone</span> <span class="o">=</span> <span class="n">make_cdd_cone</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
    <span class="p">[[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">]]))</span>
<span class="k">print</span><span class="p">(</span><span class="n">cone</span><span class="p">)</span>
</code></pre></div></div>
<p>Here is the output:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>begin
 7 9 real
  0 -1  1  0  0  0  0  0  0
  0  0 -1  1  0  0  0  0  0
  0  0  0 -1  1  0  0  0  0
  0  0  0  0 -1  1  0  0  0
  0  0  0  0  0 -1  1  0  0
  0  0  0  0  0  0 -1  1  0
  0  0  0  0  0  0  0 -1  1
end
</code></pre></div></div>

<p>The line <code class="language-plaintext highlighter-rouge">7 9 real</code> means that our cone is defined by 7 inequalities, and each inequality is described by 9 real numbers. And indeed, below it we have a matrix of 7 rows, each having 9 numbers. The first column of the matrix is all zeros, whereas the other columns are exactly the matrix \(\mathbf{A}\). The first column is the vector \(\mathbf{b}\) from the form \(\mathbf{b} + \mathbf{A}\mathbf{x} \geq 0\) that the library expects, and it is zero because our set is a cone.</p>
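<p>Typing the inequality matrix by hand does not scale to larger \(n\). A minimal sketch of a hypothetical helper (<code class="language-plaintext highlighter-rouge">nondecreasing_inequalities</code> is not part of the post or of <code class="language-plaintext highlighter-rouge">pycddlib</code>) that builds the same \(\mathbf{A}\) from two shifted identity matrices, with row \(i\) encoding \(u_{i+1} - u_i \geq 0\). Its negation describes non-increasing sequences instead.</p>

```python
import numpy as np

def nondecreasing_inequalities(n):
    """Hypothetical helper: row i encodes u[i+1] - u[i] >= 0, so A @ u >= 0
    holds exactly for non-decreasing vectors u of length n."""
    return np.eye(n - 1, n, k=1) - np.eye(n - 1, n)

A = nondecreasing_inequalities(8)
print(A.shape)   # (7, 8)
print(A[0])      # -1 at position 0, +1 at position 1, zeros elsewhere

u = np.array([0.0, 0.5, 0.5, 2.0, 3.0, 3.0, 4.0, 7.5])  # a non-decreasing example
print(np.all(A @ u >= 0), np.all(-A @ u >= 0))          # True False
```

<p>Passing this array to <code class="language-plaintext highlighter-rouge">make_cdd_cone</code> should describe the same cone as the hard-coded matrix above.</p>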

<p>Now this is where the library is useful: it lets us convert the inequality form to the generator form! So let’s do it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">generators</span> <span class="o">=</span> <span class="n">cdd</span><span class="p">.</span><span class="n">copy_generators</span><span class="p">(</span><span class="n">cone</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">generators</span><span class="p">)</span>
</code></pre></div></div>

<p>Here is the output:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>V-representation
linearity 1  8
begin
 8 9 real
  0 -1  0  0  0  0  0  0  0
  0 -1 -1  0  0  0  0  0  0
  0 -1 -1 -1  0  0  0  0  0
  0 -1 -1 -1 -1  0  0  0  0
  0 -1 -1 -1 -1 -1  0  0  0
  0 -1 -1 -1 -1 -1 -1  0  0
  0 -1 -1 -1 -1 -1 -1 -1  0
  0  1  1  1  1  1  1  1  1
end
</code></pre></div></div>

<p>The first line, <code class="language-plaintext highlighter-rouge">V-represenation</code>, means that our set is in the generators form, as opposed to the inequality form. The second line, <code class="language-plaintext highlighter-rouge">linearity 1 8</code> means we have <em>one</em> “linear” generator, and it’s the \(8^{\mathrm{th}}\) generator.  Note, that in the printout, the indices begin from 1, rather than from 0.  Then, we have a matrix of generators, again, with a column of zeros. The zero column has a special mathematical meaning, which here we shall interpret as ‘these are generators of a cone’.  Now here is an important detail - in contrast to the mathematical convention in this post, generators are in the <em>rows</em>, rather than the columns of the matrix. To convert this matrix to our desired form, we need to <em>transpose</em> it.</p>

<p>The generators object above also has an <code class="language-plaintext highlighter-rouge">array</code> property with the generators, and a <code class="language-plaintext highlighter-rouge">lin_set</code> property with the set of linear generators. So let’s create a convenience function that prints the generators in the columns of a matrix, together with the set of linear generators:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">print_generators</span><span class="p">(</span><span class="n">cone</span><span class="p">):</span>
  <span class="n">generators</span> <span class="o">=</span> <span class="n">cdd</span><span class="p">.</span><span class="n">copy_generators</span><span class="p">(</span><span class="n">cone</span><span class="p">)</span>
  <span class="k">print</span><span class="p">(</span><span class="s">'Linear generators: '</span><span class="p">,</span> <span class="n">generators</span><span class="p">.</span><span class="n">lin_set</span><span class="p">)</span>
  <span class="k">print</span><span class="p">(</span><span class="s">'Generator matrix: '</span><span class="p">)</span>
  <span class="n">gen_mat</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">generators</span><span class="p">.</span><span class="n">array</span><span class="p">)</span>
  <span class="n">gen_mat</span> <span class="o">=</span> <span class="n">gen_mat</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:]</span> <span class="c1"># discard the first column of zeros
</span>  <span class="n">gen_mat</span> <span class="o">=</span> <span class="n">gen_mat</span><span class="p">.</span><span class="n">T</span>
  <span class="k">print</span><span class="p">(</span><span class="n">gen_mat</span><span class="p">)</span>

<span class="n">print_generators</span><span class="p">(</span><span class="n">cone</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Linear generators:  {7}
Generator matrix: 
[[-1. -1. -1. -1. -1. -1. -1.  1.]
 [ 0. -1. -1. -1. -1. -1. -1.  1.]
 [ 0.  0. -1. -1. -1. -1. -1.  1.]
 [ 0.  0.  0. -1. -1. -1. -1.  1.]
 [ 0.  0.  0.  0. -1. -1. -1.  1.]
 [ 0.  0.  0.  0.  0. -1. -1.  1.]
 [ 0.  0.  0.  0.  0.  0. -1.  1.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.]]

</code></pre></div></div>

<p>Indeed, it’s one of the two matrices \(\mathbf{R}\) we <em>manually</em> obtained for non-decreasing sequences. Note that the <code class="language-plaintext highlighter-rouge">lin_set</code> property uses zero-based indices, unlike the one-based printout above. But now we can see that the process does not have to be manual: it can be <em>automated</em>.</p>
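<p>To close the loop with the PyTorch side, one could convert the library’s output directly into the generator matrix and linear mask that <code class="language-plaintext highlighter-rouge">polyhedral_cone_module</code> expects. Below is a sketch of a hypothetical helper, <code class="language-plaintext highlighter-rouge">cdd_to_torch</code> (not part of <code class="language-plaintext highlighter-rouge">pycddlib</code>). It takes plain data, so it can be exercised here on the rows printed above without cddlib installed; it assumes the input contains rays only.</p>

```python
import torch

def cdd_to_torch(array, lin_set):
    """Hypothetical helper (not part of pycddlib): convert cddlib's V-representation
    rows into a generator matrix with the rays in its columns, plus a boolean mask
    of the linear generators. `lin_set` holds zero-based row indices."""
    rows = torch.tensor(array, dtype=torch.float32)
    if not torch.all(rows[:, 0] == 0):
        raise ValueError("expected rays only: the first column must be all zeros")
    generators = rows[:, 1:].T.contiguous()   # rays move from rows to columns
    linear_mask = torch.zeros(generators.shape[1], dtype=torch.bool)
    for i in lin_set:
        linear_mask[i] = True
    return generators, linear_mask

# The V-representation printed above for non-decreasing sequences of length 8.
n = 8
array = [[0.0] + [-1.0] * k + [0.0] * (n - k) for k in range(1, n)] + [[0.0] + [1.0] * n]
generators, linear_mask = cdd_to_torch(array, {7})
print(generators.shape, linear_mask.nonzero().flatten().tolist())  # torch.Size([8, 8]) [7]
```

<p>With <code class="language-plaintext highlighter-rouge">pycddlib</code>, the inputs would come from <code class="language-plaintext highlighter-rouge">g = cdd.copy_generators(cone)</code> as <code class="language-plaintext highlighter-rouge">cdd_to_torch(g.array, g.lin_set)</code>, and the two results could then be passed to <code class="language-plaintext highlighter-rouge">polyhedral_cone_module(generators, linear_mask)</code>.</p>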

<p>What about non-increasing sequences? We could easily derive the generators ourselves, but why bother? Let’s let the library do it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># print generators for non-increasing sequences
# note - the matrix is exactly the negated matrix of non-decreasing sequences
# we used above.
</span><span class="n">cone</span> <span class="o">=</span> <span class="n">print_generators</span><span class="p">(</span><span class="n">make_cdd_cone</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
    <span class="p">[[</span><span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">]])))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Linear generators:  {7}
Generator matrix: 
[[1. 1. 1. 1. 1. 1. 1. 1.]
 [0. 1. 1. 1. 1. 1. 1. 1.]
 [0. 0. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 1.]]
</code></pre></div></div>

<p>Well, we can see the pattern, right? It’s just an upper-triangular matrix of ones. So we can also implement a corresponding PyTorch layer:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">nonincreasing_module</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
  <span class="n">generators</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">triu</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>

  <span class="n">linear_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
  <span class="n">linear_mask</span><span class="p">[</span><span class="n">n</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span> <span class="c1"># the last generator is linear
</span>
  <span class="k">return</span> <span class="n">polyhedral_cone_module</span><span class="p">(</span><span class="n">generators</span><span class="p">,</span> <span class="n">linear_mask</span><span class="p">)</span>
</code></pre></div></div>
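<p>Why does an upper-triangular matrix of ones describe non-increasing sequences? The short NumPy sketch below is not from the original notebook. It shows that this matrix is the inverse of the first-difference operator. If \(\mathbf{u} = \mathbf{G}\mathbf{c}\), then \(u_i - u_{i+1} = c_i\) for \(i = 1, \dots, n-1\) and \(u_n = c_n\). Non-negative coefficients on the first \(n-1\) generators are therefore the same as non-negative consecutive drops. The last, linear, generator shifts the whole sequence by an arbitrary constant.</p>

```python
import numpy as np

n = 8
G = np.triu(np.ones((n, n)))  # column k is (1, ..., 1, 0, ..., 0)

# Upper-bidiagonal difference matrix: row i computes u_i - u_{i+1},
# and the last row simply picks the last entry u_n.
D = np.eye(n) - np.eye(n, k=1)

# G is the inverse of D, so applying D to G @ c recovers the coefficients c.
print(np.allclose(D @ G, np.eye(n)))

c = np.array([0.3, 0.0, 1.2, 0.5, 0.0, 0.0, 0.7, -2.0])  # last entry is unconstrained
u = G @ c
# Consecutive differences u_{i+1} - u_i equal -c_i, so they are all <= 0.
print(np.diff(u))
```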

<p>Let’s see if it works:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">nonincreasing_module</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>

<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">mod</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="n">mod</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="n">mod</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([ 1.2874,  0.9507,  0.8596,  0.7242,  0.6091,  0.6091,  0.6091, -0.2256])
tensor([2.1994, 1.7378, 1.5487, 1.2399, 0.8352, 0.3387, 0.3387, 0.3387])
tensor([3.0661, 1.7440, 1.1661, 1.1661, 1.1661, 0.5613, 0.2811, 0.2811])
</code></pre></div></div>
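<p>The <code class="language-plaintext highlighter-rouge">polyhedral_cone_module</code> function is not shown in this part of the post, so its internals are not repeated here. The sketch below is purely an illustration of one plausible way such a layer could map an unconstrained vector into the cone. It <em>assumes</em> a ReLU on the coefficients of the non-linear generators and leaves the linear ones unconstrained. The name <code class="language-plaintext highlighter-rouge">cone_map</code> is hypothetical. With this choice, a zero coefficient produces two equal consecutive entries, similar to the repeated values in the output above.</p>

```python
import numpy as np

def cone_map(generators, linear_mask, t):
    """Hypothetical sketch of a cone layer (an assumption, not the post's implementation).

    Coefficients of non-linear generators are clipped at zero (ReLU), coefficients
    of linear generators pass through unchanged, and the result is G @ coeffs.
    """
    coeffs = np.where(linear_mask, t, np.maximum(t, 0.0))
    return generators @ coeffs


n = 8
generators = np.triu(np.ones((n, n)))   # same generators as nonincreasing_module
linear_mask = np.zeros(n, dtype=bool)
linear_mask[-1] = True                  # the constant sequence is a linear generator

rng = np.random.default_rng(0)
outputs = [cone_map(generators, linear_mask, rng.standard_normal(n)) for _ in range(100)]

# Every consecutive difference should be <= 0 (up to floating-point round-off).
print(all(np.all(np.diff(u) <= 1e-12) for u in outputs))
```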

<p>Appears to be working! What about convex sequences? For those, we require the second-order differences of the sequence to be non-negative:</p>

\[u_{i + 1} - 2 u_i + u_{i-1} \geq 0, \qquad i = 2, \dots, n-1\]
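<p>Typing such inequality matrices by hand becomes tedious as \(n\) grows. Here is a small sketch, assuming NumPy as in the cells that call <code class="language-plaintext highlighter-rouge">make_cdd_cone</code>. Taking differences of the rows of the identity matrix yields exactly the rows used in this post. Second-order differences give the \([1, -2, 1]\) rows, and first-order differences give \([-1, 1]\) rows for non-decreasing sequences. Negating those gives the \([1, -1]\) rows used above for non-increasing ones.</p>

```python
import numpy as np

n = 8
I = np.eye(n)

# Row i of each matrix is the difference operator applied to the unit vectors.
first_diff = np.diff(I, n=1, axis=0)   # rows like [-1, 1, 0, ...]: u_{i+1} - u_i
second_diff = np.diff(I, n=2, axis=0)  # rows like [1, -2, 1, 0, ...]: u_{i-1} - 2u_i + u_{i+1}

print(second_diff.astype(int))

# Multiplying by a sequence applies the operator; a convex sequence gives non-negative values.
u = (np.arange(n) - 3.5) ** 2
print(second_diff @ u)  # every entry is 2.0 for this quadratic
```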

<p>So let’s compute the generator form and see the pattern:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># print generators for a convex sequence
</span><span class="n">cone</span> <span class="o">=</span> <span class="n">print_generators</span><span class="p">(</span><span class="n">make_cdd_cone</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
    <span class="p">[[</span><span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">2</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">2</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">2</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">2</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">2</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>  <span class="o">-</span><span class="mi">2</span><span class="p">,</span>  <span class="mi">1</span><span class="p">]])))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Linear generators:  {6, 7}
Generator matrix: 
[[ 1.  2.  3.  4.  5.  6.  7. -6.]
 [ 0.  1.  2.  3.  4.  5.  6. -5.]
 [ 0.  0.  1.  2.  3.  4.  5. -4.]
 [ 0.  0.  0.  1.  2.  3.  4. -3.]
 [ 0.  0.  0.  0.  1.  2.  3. -2.]
 [ 0.  0.  0.  0.  0.  1.  2. -1.]
 [ 0.  0.  0.  0.  0.  0.  1.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.]]
</code></pre></div></div>

<p>It takes a few seconds to see the pattern, but it’s not hard. The generator matrix is divided into four blocks:</p>

\[\left(
\begin{array}{ccccc|c}
1 &amp; 2 &amp; 3 &amp; \dots &amp; n - 1 &amp; -(n - 2) \\
0 &amp; 1 &amp; 2 &amp; \dots &amp; n - 2 &amp; -(n - 3) \\
\vdots &amp; \ddots &amp; \ddots &amp; 1 &amp; 2 &amp; -1 \\
0 &amp; 0 &amp; \dots &amp; 0 &amp; 1 &amp; 0 \\
\hline
0 &amp; 0 &amp; \dots &amp; 0 &amp; 0 &amp; 1
\end{array}
\right)\]

<p>The top-left block is upper-triangular, and its \(k\)-th row holds the progression \(1, 2, \dots, n - k\), starting on the diagonal. The bottom-left block is a row of zeros. The top-right block is a column increasing from \(-(n-2)\) to zero, and the bottom-right block is the scalar 1. Beyond the block structure, we see that the last <em>two</em> generators, indexed 6 and 7, are the linear generators. So let’s implement a function that creates such a generator matrix:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_convex_generators</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
  <span class="n">top_left</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">triu</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">((</span><span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  <span class="n">top_right</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="n">n</span><span class="o">-</span><span class="mi">2</span><span class="p">),</span> <span class="mi">1</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  <span class="n">top</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">top_left</span><span class="p">,</span> <span class="n">top_right</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

  <span class="n">bottom_left</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
  <span class="n">bottom_right</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
  <span class="n">bottom</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">bottom_left</span><span class="p">,</span> <span class="n">bottom_right</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

  <span class="n">generators</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">top</span><span class="p">,</span> <span class="n">bottom</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">generators</span>

<span class="n">make_convex_generators</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7., -6.],
        [ 0.,  1.,  2.,  3.,  4.,  5.,  6., -5.],
        [ 0.,  0.,  1.,  2.,  3.,  4.,  5., -4.],
        [ 0.,  0.,  0.,  1.,  2.,  3.,  4., -3.],
        [ 0.,  0.,  0.,  0.,  1.,  2.,  3., -2.],
        [ 0.,  0.,  0.,  0.,  0.,  1.,  2., -1.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.]])
</code></pre></div></div>

<p>Appears to work! So now let’s implement our PyTorch module:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">convex_module</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
  <span class="n">generators</span> <span class="o">=</span> <span class="n">make_convex_generators</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
  <span class="n">linear_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
  <span class="n">linear_mask</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]</span> <span class="o">=</span> <span class="bp">True</span>
  <span class="k">return</span> <span class="n">polyhedral_cone_module</span><span class="p">(</span><span class="n">generators</span><span class="p">,</span> <span class="n">linear_mask</span><span class="p">)</span>
</code></pre></div></div>
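<p>It also helps to look at what the columns actually are. Each of the first \(n-2\) generators is a “hinge” sequence: column \(j\) decreases linearly from \(j\) down to zero and then stays flat. The two linear generators are arithmetic progressions, which are both convex and concave. The NumPy sketch below uses a hypothetical port of <code class="language-plaintext highlighter-rouge">make_convex_generators</code>, not part of the original notebook. It checks that the second differences of the hinge columns are non-negative and that those of the two linear columns vanish.</p>

```python
import numpy as np

def make_convex_generators_np(n):
    # Hypothetical NumPy port of make_convex_generators, for illustration only.
    top_left = np.cumsum(np.triu(np.ones((n - 1, n - 1))), axis=1)
    top_right = np.arange(-(n - 2), 1, dtype=float).reshape(-1, 1)
    top = np.hstack([top_left, top_right])
    bottom = np.hstack([np.zeros((1, n - 1)), np.ones((1, 1))])
    return np.vstack([top, bottom])


n = 8
G = make_convex_generators_np(n)
second_diff = np.diff(np.eye(n), n=2, axis=0)  # rows [1, -2, 1, 0, ...]

D2G = second_diff @ G  # column j holds the second differences of generator j
print(G[:, 2])                           # a hinge: [3. 2. 1. 0. 0. 0. 0. 0.]
print(np.all(D2G[:, :n - 2] >= 0))       # hinge columns are convex
print(np.allclose(D2G[:, n - 2:], 0.0))  # linear columns have zero curvature
```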

<p>Now let’s test it. Since convexity is easier to see visually, we will plot the resulting sequences to see if they’re indeed convex:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="n">n</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">convex_module</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>

<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">56</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">6</span><span class="p">):</span>
  <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
  <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">layer</span><span class="p">(</span><span class="n">t</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/cdd_convex_sequences.png" alt="cdd_convex_sequences" /></p>

<p>They indeed all appear convex, but something looks suspicious! Why are they all almost the same? Maybe the <code class="language-plaintext highlighter-rouge">cddlib</code> library does not compute generators for the <em>entire</em> space of convex sequences, but only for a small subset? It turns out that what we see is a result of feeding <em>normally distributed</em> random inputs to our layer. This specific distribution of input vectors produces a very specific distribution of convex sequences. With different inputs we can get much more interesting sequences, for example:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n</span><span class="p">)).</span><span class="n">square</span><span class="p">()</span>
<span class="n">t</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">n</span> <span class="o">/</span> <span class="mi">2</span>
<span class="n">t</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="o">-</span><span class="n">n</span> <span class="o">/</span> <span class="mi">2</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">layer</span><span class="p">(</span><span class="n">t</span><span class="p">))</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/cdd_convex_sequence_interesting.png" alt="cdd_convex_sequence_interesting" /></p>

<p>It indeed looks very different. This means that any model that uses our layer to generate convex sequences will have to learn to feed those ‘interesting’ vectors \(\mathbf{t}\) into our layer, so that the layer produces the correct output.</p>
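<p>To get a feeling for which coefficients produce a given convex sequence, note that the generator matrix is unit upper-triangular and hence invertible. Every sequence \(\mathbf{u}\) therefore has exactly one coefficient vector \(\mathbf{c}\) with \(\mathbf{u} = \mathbf{G}\mathbf{c}\). The hinge coefficients turn out to be the second differences of \(\mathbf{u}\), so they are non-negative exactly when \(\mathbf{u}\) is convex. The NumPy sketch below checks this. Note that \(\mathbf{c}\) holds generator coefficients, which need not coincide with the layer input \(\mathbf{t}\) if the layer transforms its input first.</p>

```python
import numpy as np

n = 8
# Same matrix as make_convex_generators(8), rebuilt with NumPy.
top = np.hstack([np.cumsum(np.triu(np.ones((n - 1, n - 1))), axis=1),
                 np.arange(-(n - 2), 1, dtype=float).reshape(-1, 1)])
G = np.vstack([top, np.eye(n)[-1:]])

u = (np.arange(n) - 3.5) ** 2   # a convex target sequence
c = np.linalg.solve(G, u)       # unique coefficients with G @ c == u

second_diff = np.diff(u, n=2)   # u_{i-1} - 2 u_i + u_{i+1}
print(np.allclose(c[:n - 2], second_diff))  # hinge coefficients are the second differences
print(c[n - 2:])                              # coefficients of the two linear generators
```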

<p>Concave sequences are handled almost identically, since the defining inequality is just the negation of the convexity inequality:</p>

\[-u_{i-1} + 2 u_i - u_{i+1} \geq 0, \qquad i = 2, \dots, n - 1\]

<p>Let’s ask the <code class="language-plaintext highlighter-rouge">cddlib</code> library to create a generator matrix for us:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># print generators for a concave sequence
</span><span class="n">cone</span> <span class="o">=</span> <span class="n">print_generators</span><span class="p">(</span><span class="n">make_cdd_cone</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
    <span class="p">[[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">]])))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Linear generators:  {6, 7}
Generator matrix: 
[[-1. -2. -3. -4. -5. -6.  7. -6.]
 [ 0. -1. -2. -3. -4. -5.  6. -5.]
 [ 0.  0. -1. -2. -3. -4.  5. -4.]
 [ 0.  0.  0. -1. -2. -3.  4. -3.]
 [ 0.  0.  0.  0. -1. -2.  3. -2.]
 [ 0.  0.  0.  0.  0. -1.  2. -1.]
 [ 0.  0.  0.  0.  0.  0.  1.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.]]
</code></pre></div></div>

<p>The pattern is very close to the convex case: the first \(n-2\) columns are negated, and here too the last two generators are linear. It’s not hard to figure out that the following code produces this matrix directly with PyTorch:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_concave_generators</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
  <span class="n">top_left</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">triu</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">((</span><span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  <span class="n">top_right</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="n">n</span><span class="o">-</span><span class="mi">2</span><span class="p">),</span> <span class="mi">1</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  <span class="n">top</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">top_left</span><span class="p">,</span> <span class="n">top_right</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

  <span class="n">bottom_left</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
  <span class="n">bottom_right</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
  <span class="n">bottom</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">bottom_left</span><span class="p">,</span> <span class="n">bottom_right</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

  <span class="n">mat</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">top</span><span class="p">,</span> <span class="n">bottom</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
  <span class="n">mat</span><span class="p">[:</span><span class="n">n</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="p">:</span><span class="n">n</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="o">*=</span> <span class="o">-</span><span class="mi">1</span>
  <span class="k">return</span> <span class="n">mat</span>

<span class="n">make_concave_generators</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[-1., -2., -3., -4., -5., -6.,  7., -6.],
        [-0., -1., -2., -3., -4., -5.,  6., -5.],
        [-0., -0., -1., -2., -3., -4.,  5., -4.],
        [-0., -0., -0., -1., -2., -3.,  4., -3.],
        [-0., -0., -0., -0., -1., -2.,  3., -2.],
        [-0., -0., -0., -0., -0., -1.,  2., -1.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.]])
</code></pre></div></div>
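<p>The negated columns have a simple explanation. A sequence \(\mathbf{u}\) is concave exactly when \(-\mathbf{u}\) is convex, so negating the hinge generators of the convex cone yields generators for the concave one. The two linear generators (arithmetic progressions) serve both cones unchanged. This is what the line <code class="language-plaintext highlighter-rouge">mat[:n-2, :n-2] *= -1</code> does, since the remaining entries of those columns are zero. A quick NumPy check, using a hypothetical port of the convex generator code:</p>

```python
import numpy as np

def convex_generators_np(n):
    # Hypothetical NumPy port of make_convex_generators, for illustration only.
    top = np.hstack([np.cumsum(np.triu(np.ones((n - 1, n - 1))), axis=1),
                     np.arange(-(n - 2), 1, dtype=float).reshape(-1, 1)])
    return np.vstack([top, np.eye(n)[-1:]])


n = 8
G_convex = convex_generators_np(n)
G_concave = G_convex.copy()
G_concave[:, :n - 2] *= -1  # negate the hinge generators, keep the two linear ones

neg_second_diff = -np.diff(np.eye(n), n=2, axis=0)  # rows [-1, 2, -1, 0, ...]
print(np.all(neg_second_diff @ G_concave[:, :n - 2] >= 0))       # concavity holds
print(np.allclose(neg_second_diff @ G_concave[:, n - 2:], 0.0))  # linear columns
```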

<p>Now we can implement and try out the concave module:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">concave_module</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
  <span class="n">generators</span> <span class="o">=</span> <span class="n">make_concave_generators</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
  <span class="n">linear_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
  <span class="n">linear_mask</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]</span> <span class="o">=</span> <span class="bp">True</span>
  <span class="k">return</span> <span class="n">polyhedral_cone_module</span><span class="p">(</span><span class="n">generators</span><span class="p">,</span> <span class="n">linear_mask</span><span class="p">)</span>
</code></pre></div></div>

<pre><code class="language-python3">n = 32
layer = concave_module(n)

torch.manual_seed(42)
for i in range(10):
  t = torch.randn(n)
  plt.plot(layer(t))
plt.show()
</code></pre>

<p><img src="https://alexshtf.github.io/assets/cdd_concave_sequences.png" alt="cdd_concave_sequences" /></p>

<p>Indeed, they all look concave!</p>

<p>So what if we want a concave and non-decreasing sequence? Easy peasy! It just has to satisfy both types of inequalities:</p>

\[\begin{align*}
-u_{i-1} + 2 u_i - u_{i+1} \geq 0 &amp; \qquad i = 2, \dots, n - 1 \\
u_{i+1} - u_i \geq 0 &amp; \qquad i = 1, \dots, n - 1
\end{align*}\]
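<p>The combined constraint matrix can also be assembled programmatically. Stack the negated second-difference rows on top of the first-difference rows, which matches the row layout of the matrix passed to <code class="language-plaintext highlighter-rouge">make_cdd_cone</code> in the next code block. A NumPy sketch, with a sanity check on the concave, non-decreasing sequence \(u_i = \sqrt{i}\):</p>

```python
import numpy as np

n = 8
I = np.eye(n)
concave_rows = -np.diff(I, n=2, axis=0)   # -u_{i-1} + 2u_i - u_{i+1} >= 0
nondecreasing_rows = np.diff(I, axis=0)   # u_{i+1} - u_i >= 0
A = np.vstack([concave_rows, nondecreasing_rows])
print(A.shape)  # (13, 8)

u = np.sqrt(np.arange(n))         # concave and non-decreasing
print(np.all(A @ u >= 0))         # satisfies every inequality
print(np.all(A @ u[::-1] >= 0))   # reversed: still concave, but non-increasing, so it fails
```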

<p>So we just concatenate the two inequality matrices we saw above, one after the other, and ask <code class="language-plaintext highlighter-rouge">cddlib</code> to make the appropriate generator matrix:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># print generators for concave non-decreasing sequences
</span><span class="n">print_generators</span><span class="p">(</span><span class="n">make_cdd_cone</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
    <span class="p">[[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">2</span><span class="p">,</span>  <span class="o">-</span><span class="mi">1</span><span class="p">],</span>
     <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>   <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>   <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>   <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>   <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>   <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>  <span class="mi">1</span><span class="p">,</span>   <span class="mi">0</span><span class="p">],</span>
     <span class="p">[</span> <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>   <span class="mi">1</span><span class="p">]]</span>
    <span class="p">)))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Linear generators:  {7}
Generator matrix: 
[[-7. -6. -5. -4. -3. -2. -1.  1.]
 [-6. -5. -4. -3. -2. -1.  0.  1.]
 [-5. -4. -3. -2. -1.  0.  0.  1.]
 [-4. -3. -2. -1.  0.  0.  0.  1.]
 [-3. -2. -1.  0.  0.  0.  0.  1.]
 [-2. -1.  0.  0.  0.  0.  0.  1.]
 [-1.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.]]
</code></pre></div></div>

<p>The pattern is quite simple: a triangular block of negative integers that increase toward zero along each row and each column, followed by a column of ones. Only the last generator, the column of ones, is linear. The following PyTorch code generates such a matrix for any size \(n\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">concave_nondecreasing_generators</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
  <span class="n">generators</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">).</span><span class="n">triu</span><span class="p">(</span><span class="mi">1</span><span class="p">).</span><span class="n">cumsum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">).</span><span class="n">fliplr</span><span class="p">()</span>
  <span class="n">generators</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">generators</span>

<span class="n">concave_nondecreasing_generators</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[-7., -6., -5., -4., -3., -2., -1.,  1.],
        [-6., -5., -4., -3., -2., -1., -0.,  1.],
        [-5., -4., -3., -2., -1., -0., -0.,  1.],
        [-4., -3., -2., -1., -0., -0., -0.,  1.],
        [-3., -2., -1., -0., -0., -0., -0.,  1.],
        [-2., -1., -0., -0., -0., -0., -0.,  1.],
        [-1., -0., -0., -0., -0., -0., -0.,  1.],
        [-0., -0., -0., -0., -0., -0., -0.,  1.]])
</code></pre></div></div>
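<p>Before trusting the pattern for larger \(n\), we can sanity-check it numerically. Below is a minimal sketch. It assumes the cone is described by the same kind of inequality rows as the matrix passed to <code class="language-plaintext highlighter-rouge">cddlib</code> above: non-decreasing rows such as \((-1, 1, 0, \dots)\) and concavity rows such as \((0, -1, 2, -1, \dots)\). The helper <code class="language-plaintext highlighter-rouge">concave_nondecreasing_inequalities</code> is hypothetical. The sketch checks two things. Every ordinary generator must satisfy all inequalities. The linear generator must satisfy them with equality, because it may be used with either sign.</p>

```python
import torch

def concave_nondecreasing_generators(n):
  # same construction as in the post
  generators = -torch.ones(n, n).triu(1).cumsum(dim=1).fliplr()
  generators[:, -1] = torch.ones(n)
  return generators

def concave_nondecreasing_inequalities(n):
  # Hypothetical helper: rows of A such that A @ a >= 0 encodes
  #   concavity:      -a[i-1] + 2 a[i] - a[i+1] >= 0  for i = 1..n-2
  #   non-decreasing:  a[i+1] - a[i]            >= 0  for i = 0..n-2
  concave = torch.zeros(n - 2, n)
  for i in range(1, n - 1):
    concave[i - 1, i - 1:i + 2] = torch.tensor([-1., 2., -1.])
  nondecreasing = torch.zeros(n - 1, n)
  for i in range(n - 1):
    nondecreasing[i, i:i + 2] = torch.tensor([-1., 1.])
  return torch.cat([concave, nondecreasing])

n = 32
G = concave_nondecreasing_generators(n)
A = concave_nondecreasing_inequalities(n)

# each ordinary generator (a column of G) must satisfy every inequality
print((A @ G[:, :-1] >= 0).all().item())
# the linear generator can be scaled by either sign, so it must give equality
print((A @ G[:, -1] == 0).all().item())
```

<p>The checks involve only small integers, so they hold exactly in floating point and need no tolerance.</p>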

<p>Here is our concave-nondecreasing layer:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">concave_nondecreasing_cone_module</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
  <span class="n">generators</span> <span class="o">=</span> <span class="n">concave_nondecreasing_generators</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
  <span class="n">linear_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
  <span class="n">linear_mask</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
  <span class="k">return</span> <span class="n">polyhedral_cone_module</span><span class="p">(</span><span class="n">generators</span><span class="p">,</span> <span class="n">linear_mask</span><span class="p">)</span>
</code></pre></div></div>

<p>Let’s try it out with a few random inputs:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">concave_nondecreasing_cone_module</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>

<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
  <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
  <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">layer</span><span class="p">(</span><span class="n">t</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/cdd_concave_nondecreasing_sequences.png" alt="cdd_concave_nondecreasing_sequences" /></p>

<p>Beautiful! This time, devising the generator matrix for concave and non-decreasing sequences ourselves is not trivial. We might find the right generator matrix by thinking about it. Speaking for myself, though, it would take me a very long time of ‘thinking about it’ to come up with the correct set of generators. The algorithms implemented in the <code class="language-plaintext highlighter-rouge">cddlib</code> library automate this process and produce a <em>provably correct</em> generator matrix and linear set.</p>

<p>So now we have an interesting method to constrain the output of a model to lie in our desired polyhedral cone:</p>

<ol>
  <li>Feed <code class="language-plaintext highlighter-rouge">cddlib</code> with the appropriate inequalities, and let it create the generators and the linear set.</li>
  <li>Implement PyTorch code that reproduces the pattern of the generator matrix for an arbitrary dimension \(n\).</li>
  <li>Feed the generator matrix to <code class="language-plaintext highlighter-rouge">polyhedral_cone_module</code> to produce a PyTorch layer that transforms arbitrary vectors into vectors that <em>provably</em> lie in the cone we desire.</li>
</ol>
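<p>Step 3 can be illustrated with a minimal, hypothetical stand-in for <code class="language-plaintext highlighter-rouge">polyhedral_cone_module</code>, whose actual implementation is not reproduced in this part of the post. The sketch uses the standard way of parametrizing a cone from its generators. Ordinary generators are combined with nonnegative weights, here obtained with softplus. Linear generators are combined with weights of any sign. Any such combination lies in the cone, and the code checks this numerically for the concave non-decreasing generators.</p>

```python
import torch
import torch.nn.functional as F

def concave_nondecreasing_generators(n):
  # same construction as concave_nondecreasing_generators in the post
  generators = -torch.ones(n, n).triu(1).cumsum(dim=1).fliplr()
  generators[:, -1] = torch.ones(n)
  return generators

class ConeCombinationSketch(torch.nn.Module):
  # Hypothetical stand-in, not necessarily how polyhedral_cone_module works:
  # maps t to G @ w, with w[j] = softplus(t[j]) >= 0 for ordinary generators
  # and w[j] = t[j] (any sign) for linear generators.
  def __init__(self, generators, linear_mask):
    super().__init__()
    self.register_buffer('generators', generators)
    self.register_buffer('linear_mask', linear_mask)

  def forward(self, t):
    w = torch.where(self.linear_mask, t, F.softplus(t))
    return w @ self.generators.T  # row-wise G @ w for a batch of shape (batch, n)

n = 32
linear_mask = torch.zeros(n, dtype=torch.bool)
linear_mask[-1] = True
layer = ConeCombinationSketch(concave_nondecreasing_generators(n), linear_mask)

torch.manual_seed(0)
a = layer(torch.randn(10, n))  # 10 sequences of length n
d = a.diff(dim=-1)             # first differences
print('non-decreasing:', (d >= -1e-5).all().item())
print('concave:', (d.diff(dim=-1) <= 1e-5).all().item())
```

<p>Softplus is only one choice of nonnegative map. ReLU or exp would also keep the combination inside the cone.</p>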

<p>Now let’s combine this with our shape-constrained polynomial functions from the last post, to fit a model that produces shape-constrained functions.</p>

<h1 id="example---fitting-concave-functions">Example - fitting concave functions</h1>

<p>Now let’s remind ourselves <em>why</em> we wanted to constrain the output vector of a neural network to satisfy conic constraints. We want to use such vectors as coefficients of a Bernstein polynomial, which inherits properties of its coefficients, such as monotonicity or convexity. We continue the adventure we started in the previous post of fitting increasing, decreasing, convex, or concave functions using neural networks!</p>
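<p>The claim that a Bernstein polynomial inherits the shape of its coefficients can be checked numerically. Below is a minimal sketch using the standard Bernstein basis on \([0, 1]\). The helper <code class="language-plaintext highlighter-rouge">bernstein_values</code> is hypothetical and is not the <code class="language-plaintext highlighter-rouge">bernstein_basis</code> function from the previous post. The sketch picks concave coefficients, whose consecutive differences decrease, and verifies that the sampled polynomial has nonpositive second differences.</p>

```python
import math
import torch

def bernstein_values(coefs, z):
  # p(z) = sum_k coefs[k] * C(m, k) * z**k * (1 - z)**(m - k), with m = len(coefs) - 1
  m = len(coefs) - 1
  k = torch.arange(m + 1, dtype=z.dtype)
  binom = torch.tensor([math.comb(m, i) for i in range(m + 1)], dtype=z.dtype)
  basis = binom * z[:, None] ** k * (1 - z[:, None]) ** (m - k)  # shape (len(z), m + 1)
  return basis @ coefs

# concave coefficients: consecutive differences 1.0, 0.7, 0.4, 0.2, -0.1, -0.4 keep decreasing
coefs = torch.tensor([0.0, 1.0, 1.7, 2.1, 2.3, 2.2, 1.8], dtype=torch.float64)
z = torch.linspace(0, 1, 201, dtype=torch.float64)
p = bernstein_values(coefs, z)

second_diffs = p.diff().diff()
print(second_diffs.max().item())  # should be <= 0 (up to rounding) for a concave function
print(p[0].item(), p[-1].item())  # endpoints equal the first and last coefficients
```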

<p>So let’s reuse the <code class="language-plaintext highlighter-rouge">BernsteinPolynomialModel</code> class from the previous post, which represents a function \(f(\mathbf{x}, z)\) whose shape, as a function of \(z\), satisfies a given constraint. For completeness, I repeat its definition here:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">BernsteinPolynomialModel</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_model</span><span class="p">,</span> <span class="n">coef_transformer</span><span class="p">):</span>
    <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">x_model</span> <span class="o">=</span> <span class="n">x_model</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">coef_transformer</span> <span class="o">=</span> <span class="n">coef_transformer</span>

  <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
    <span class="n">coefs</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">coef_transformer</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">x_model</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
    <span class="n">degree</span> <span class="o">=</span> <span class="n">coefs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="n">basis</span> <span class="o">=</span> <span class="n">bernstein_basis</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">coefs</span> <span class="o">*</span> <span class="n">basis</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<p>The shape restriction comes from the <code class="language-plaintext highlighter-rouge">coef_transformer</code> sub-module, which produces shape-constrained coefficients for the Bernstein polynomial basis. Here, we use the <code class="language-plaintext highlighter-rouge">concave_module</code>, which we implemented using our polyhedral cone library, to make sure that our model fits concave functions of \(z\).</p>

<p>For the demo, we shall generate a synthetic data-set using the following function:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="k">def</span> <span class="nf">np_softplus</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
  <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">log1p</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>

<span class="k">def</span> <span class="nf">hairy_concave_func</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
  <span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">x3</span> <span class="o">=</span> <span class="n">x</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">x</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">x</span><span class="p">[...,</span> <span class="mi">2</span><span class="p">]</span>
  <span class="k">return</span> <span class="n">np_softplus</span><span class="p">(</span><span class="n">x1</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">0.05</span> <span class="o">+</span> <span class="n">z</span><span class="p">)</span> <span class="o">-</span> <span class="n">np_softplus</span><span class="p">(</span><span class="n">x3</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">z</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">x2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
</code></pre></div></div>

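<p>As a function of \(z\), it has the form \(a \sqrt{0.05 + z} - b(z - c)^2\), where \(a = \operatorname{softplus}(x_1)\), \(b = \operatorname{softplus}(x_3)\), and \(c = \cos^2(x_2)\). Softplus is always positive, so \(a, b > 0\). The function is therefore concave: a positive multiple of a square root is concave, a ‘sad’ parabola is concave, and a sum of concave functions is concave. Let’s plot a few examples for various inputs \(\mathbf{x}\):</p>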

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">zs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">zs</span><span class="p">,</span> <span class="n">hairy_concave_func</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]),</span> <span class="n">zs</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">'function 1'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">zs</span><span class="p">,</span> <span class="n">hairy_concave_func</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">]),</span> <span class="n">zs</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">'function 2'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">zs</span><span class="p">,</span> <span class="n">hairy_concave_func</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">1.5</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]),</span> <span class="n">zs</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">'function 3'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/concave_examples_to_fit.png" alt="concave_examples_to_fit" /></p>
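<p>The concavity argument can also be made explicit by differentiating twice. For \(a, b > 0\) and \(z > -0.05\):</p>

\[\frac{d^2}{dz^2}\Bigl(a \sqrt{0.05 + z} - b(z - c)^2\Bigr) = -\frac{a}{4(0.05 + z)^{3/2}} - 2b < 0,\]

<p>so the function is strictly concave in \(z\) for every \(\mathbf{x}\), and in particular on the interval \([0, 1]\) shown in the plot.</p>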

<p>To generate an entire data-set based on this function, we adopt an approach similar to the previous post. We draw random features \((\mathbf{x}, z)\), with \(\mathbf{x}\) normally distributed and \(z\) uniform on \([0, 1]\). We then generate labels using \(f(\mathbf{x}, z) + \varepsilon\), where \(\varepsilon\) is normally distributed noise and \(f\) is our hairy concave function above:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">generate_dataset</span><span class="p">(</span><span class="n">n_rows</span><span class="p">,</span> <span class="n">noise</span><span class="o">=</span><span class="mf">0.02</span><span class="p">,</span> <span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
  <span class="n">xs</span> <span class="o">=</span> <span class="n">std</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n_rows</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">+</span> <span class="n">mean</span>
  <span class="n">zs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="n">n_rows</span><span class="p">)</span>
  <span class="n">labels</span> <span class="o">=</span> <span class="n">hairy_concave_func</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">zs</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n_rows</span><span class="p">)</span> <span class="o">*</span> <span class="n">noise</span>

  <span class="n">xs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">xs</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
  <span class="n">zs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">zs</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
  <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">labels</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
  <span class="k">if</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">():</span>
    <span class="n">xs</span> <span class="o">=</span> <span class="n">xs</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>
    <span class="n">zs</span> <span class="o">=</span> <span class="n">zs</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>

  <span class="k">return</span> <span class="n">xs</span><span class="p">,</span> <span class="n">zs</span><span class="p">,</span> <span class="n">labels</span>
</code></pre></div></div>

<p>Now, similarly to the previous post, we shall generate train and validation set iterators:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">batch_iter</span> <span class="kn">import</span> <span class="n">BatchIter</span>

<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">train_iter</span> <span class="o">=</span> <span class="n">BatchIter</span><span class="p">(</span><span class="o">*</span><span class="n">generate_dataset</span><span class="p">(</span><span class="mi">50000</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
<span class="n">valid_iter</span> <span class="o">=</span> <span class="n">BatchIter</span><span class="p">(</span><span class="o">*</span><span class="n">generate_dataset</span><span class="p">(</span><span class="mi">10000</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
</code></pre></div></div>
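<p>The <code class="language-plaintext highlighter-rouge">batch_iter</code> module comes from the previous post and is not shown here. For readers without it, here is a hypothetical stand-in, <code class="language-plaintext highlighter-rouge">SimpleBatchIter</code>, that yields shuffled mini-batches of several aligned tensors. The real <code class="language-plaintext highlighter-rouge">BatchIter</code> may differ in details, such as how it shuffles or how it handles the last, smaller batch.</p>

```python
import torch

class SimpleBatchIter:
  # Hypothetical stand-in for BatchIter: iterates over shuffled mini-batches
  # of tensors that share their first dimension (e.g. xs, zs, labels).
  def __init__(self, *tensors, batch_size):
    self.tensors = tensors
    self.batch_size = batch_size

  def __len__(self):
    n = len(self.tensors[0])
    return (n + self.batch_size - 1) // self.batch_size  # number of batches

  def __iter__(self):
    n = len(self.tensors[0])
    perm = torch.randperm(n, device=self.tensors[0].device)  # new shuffle every epoch
    for start in range(0, n, self.batch_size):
      idx = perm[start:start + self.batch_size]
      yield tuple(t[idx] for t in self.tensors)

torch.manual_seed(0)
xs, zs, labels = torch.randn(1000, 3), torch.rand(1000), torch.randn(1000)
it = SimpleBatchIter(xs, zs, labels, batch_size=256)
for x_batch, z_batch, y_batch in it:
  print(x_batch.shape, z_batch.shape, y_batch.shape)
```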

<p>We shall also modify the <code class="language-plaintext highlighter-rouge">make_model</code> function from the previous post, to create a model with <em>concave</em> rather than <em>increasing</em> constraints:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_model</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">,</span> <span class="n">constrained</span><span class="o">=</span><span class="bp">True</span><span class="p">):</span>
  <span class="c1"># create a fully connected ReLU network
</span>  <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
      <span class="n">layer</span>
      <span class="k">for</span> <span class="n">in_dim</span><span class="p">,</span> <span class="n">out_dim</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_dims</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
      <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="p">[</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_dim</span><span class="p">,</span> <span class="n">out_dim</span><span class="p">),</span> <span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">()]</span>
  <span class="p">]</span>

  <span class="k">if</span> <span class="n">constrained</span><span class="p">:</span>
    <span class="c1"># define a model for x
</span>    <span class="n">x_model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>

    <span class="c1"># construct a network for predicting concave functions of z.
</span>    <span class="c1"># the number of Bernstein coefficients (polynomial degree + 1) is
</span>    <span class="c1"># the output dimension of the last layer.
</span>    <span class="k">return</span> <span class="n">BernsteinPolynomialModel</span><span class="p">(</span>
        <span class="n">x_model</span><span class="p">,</span>
        <span class="n">concave_module</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># &lt;-- CONCAVE CONSTRAINTS
</span>    <span class="p">)</span>
  <span class="k">else</span><span class="p">:</span>
    <span class="n">layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>
</code></pre></div></div>

<p>Note that the function above can create two kinds of models. The first is a concavity-constrained model, which uses our <code class="language-plaintext highlighter-rouge">concave_module</code> to produce the Bernstein coefficients. The second is a fully unconstrained model, which uses a regular fully-connected output layer instead of Bernstein polynomials.</p>
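<p>One detail of the constrained branch is worth spelling out. The list comprehension interleaves <code class="language-plaintext highlighter-rouge">Linear</code> and <code class="language-plaintext highlighter-rouge">ReLU</code> modules. The slice <code class="language-plaintext highlighter-rouge">layers[:-1]</code> drops only the trailing ReLU, so the raw outputs passed to the coefficient transformer are not forced to be nonnegative. The snippet below rebuilds just that part, using the layer sizes from the training example in this post.</p>

```python
import torch
import torch.nn as nn

layer_dims = [3, 80, 60, 40, 20]
# same comprehension as in make_model: Linear, ReLU, Linear, ReLU, ...
layers = [
    layer
    for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:])
    for layer in [nn.Linear(in_dim, out_dim), nn.ReLU()]
]
x_model = nn.Sequential(*layers[:-1])  # drop the trailing ReLU

print(len(layers), len(x_model))  # 8 7
print(x_model[-1])                # Linear(in_features=40, out_features=20, bias=True)
out = x_model(torch.randn(5, 3))
print(out.shape)                  # one unconstrained 20-dimensional vector per row
```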

<p>Now let’s train it. To that end, we reuse the functions <code class="language-plaintext highlighter-rouge">train_model</code>, <code class="language-plaintext highlighter-rouge">train_epoch</code>, and <code class="language-plaintext highlighter-rouge">evaluate_epoch</code> from the previous post. We train a model constrained to concave functions whose last layer outputs 20 values (the <code class="language-plaintext highlighter-rouge">degree</code> variable below), which are transformed into Bernstein coefficients:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lr</span> <span class="o">=</span> <span class="mf">1e-3</span>
<span class="n">weight_decay</span> <span class="o">=</span> <span class="mf">0.</span>
<span class="n">degree</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">layer_dims</span> <span class="o">=</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span>
              <span class="mi">4</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="mi">3</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="mi">2</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="n">degree</span><span class="p">]</span>
<span class="n">model</span><span class="p">,</span> <span class="n">val_loss</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span>
    <span class="n">train_iter</span><span class="p">,</span> <span class="n">valid_iter</span><span class="p">,</span> <span class="n">layer_dims</span><span class="p">,</span>
    <span class="n">optim_fn</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">,</span>
    <span class="n">optim_params</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">))</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">cpu</span><span class="p">()</span>
</code></pre></div></div>

<p>I got a validation loss of \(0.00094\). Not bad. Now let’s plot the function our model learned for \(\mathbf{x} = (-1, 0.1, 0.5)\), and compare it to the ground truth function:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>

<span class="n">plot_zs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">func</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">features</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Model function'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">hairy_concave_func</span><span class="p">(</span><span class="n">features</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">plot_zs</span><span class="p">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">label</span><span class="o">=</span><span class="s">'True function'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/fit_concave_func_1.png" alt="fit_concave_func_1" /></p>

<p>Seems close! Remember that the training and validation sets were generated by drawing \(\mathbf{x}\) from a normal distribution with zero mean and standard deviation 1. So what happens if we feed our model a vector \(\mathbf{x}\) that is unlikely under our data distribution? Let’s see what happens for \(\mathbf{x} = (-2, -3, -4)\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">2.</span><span class="p">,</span> <span class="o">-</span><span class="mf">3.</span><span class="p">,</span> <span class="o">-</span><span class="mf">4.</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">func</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">features</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Model function'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">hairy_concave_func</span><span class="p">(</span><span class="n">features</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">plot_zs</span><span class="p">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">label</span><span class="o">=</span><span class="s">'True function'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/fit_concave_func_2.png" alt="fit_concave_func_2" /></p>

<p>This is <em>very</em> far from the true function, but still concave! When we feed our model out-of-distribution data, its predictions may be inaccurate, but they will <strong>always</strong> satisfy the concavity constraint, because it is built into the model by design. So if this constraint is important for a business application, it is guaranteed!</p>

<p>Now let’s train a fully unconstrained model with similar MLP layer sizes:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lr</span> <span class="o">=</span> <span class="mf">1e-3</span>
<span class="n">weight_decay</span> <span class="o">=</span> <span class="mf">0.</span>
<span class="n">degree</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">layer_dims</span> <span class="o">=</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span>
              <span class="mi">4</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="mi">3</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="mi">2</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="n">degree</span><span class="p">]</span>
<span class="n">unconstrained_model</span><span class="p">,</span> <span class="n">val_loss</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span>
    <span class="n">train_iter</span><span class="p">,</span> <span class="n">valid_iter</span><span class="p">,</span> <span class="n">layer_dims</span><span class="p">,</span> <span class="n">constrained</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
    <span class="n">optim_fn</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">,</span>
    <span class="n">optim_params</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">))</span>
<span class="n">unconstrained_model</span> <span class="o">=</span> <span class="n">unconstrained_model</span><span class="p">.</span><span class="n">cpu</span><span class="p">()</span>
</code></pre></div></div>

<p>Note the <code class="language-plaintext highlighter-rouge">constrained=False</code> flag we pass to the <code class="language-plaintext highlighter-rouge">train_model</code> function. I got a validation loss of \(0.00107\), not far from that of our constrained model. Let’s see what this model, which is not constrained to produce concave functions, has learned. First, let’s try the same “likely” vector \(\mathbf{x}=(-1, 0.1, 0.5)\) as above:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
    <span class="n">plot_zs</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">unconstrained_model</span><span class="p">(</span><span class="n">features</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Model function'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">hairy_concave_func</span><span class="p">(</span><span class="n">features</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">plot_zs</span><span class="p">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">label</span><span class="o">=</span><span class="s">'True function'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/fit_unconstrained_func_1.png" alt="fit_unconstrained_func_1" /></p>

<p>It seems close to the truth, <em>and</em> concave. But is that a coincidence? Let’s try the “unlikely” vector \(\mathbf{x}=(-2, -3, -4)\) we tried above:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">2.</span><span class="p">,</span> <span class="o">-</span><span class="mf">3.</span><span class="p">,</span> <span class="o">-</span><span class="mf">4.</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
    <span class="n">plot_zs</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">unconstrained_model</span><span class="p">(</span><span class="n">features</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Model function'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">hairy_concave_func</span><span class="p">(</span><span class="n">features</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">plot_zs</span><span class="p">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">label</span><span class="o">=</span><span class="s">'True function'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/fit_unconstrained_func_2.png" alt="fit_unconstrained_func_2" /></p>

<p>Whoa! The model’s predicted function is both far from the true function <em>and</em> not concave! This was expected, though: a regular MLP has no mechanism that ensures its predictions are concave.</p>
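<p>Instead of judging concavity by eye, one could also check it numerically on the plotting grid. The sketch below relies on the fact that a concave function sampled on a uniform grid has non-positive second differences. The helper <code class="language-plaintext highlighter-rouge">max_second_difference</code> is a hypothetical name of mine and is not part of the post’s notebook.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

def max_second_difference(values):
    """Largest second difference of a function sampled on a uniform grid.

    Every second difference of a concave function is non-positive, so a clearly
    positive result is numerical evidence that the sampled function is not concave.
    """
    values = values.flatten()
    second_diff = values[2:] - 2 * values[1:-1] + values[:-2]
    return second_diff.max().item()

zs = torch.linspace(0, 1, 100, dtype=torch.float64)
print(max_second_difference(torch.sqrt(zs)))  # concave: a negative number
print(max_second_difference(zs ** 2))         # convex: a positive number
</code></pre></div></div>

<p>The model’s predictions on <code class="language-plaintext highlighter-rouge">plot_zs</code>, detached from the graph, can be passed in directly. For the constrained model the result should be non-positive up to floating-point rounding. A clearly positive value indicates that the predicted function is not concave.</p>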

<h1 id="summary">Summary</h1>

<p>In this post we explored an interesting technique for enforcing constraints on the output vectors of neural-network layers when these constraints are <em>polyhedral cone</em> constraints. We used a library, <code class="language-plaintext highlighter-rouge">cddlib</code>, to convert a polyhedral cone into an equivalent “generative” representation that integrates well with how multi-layer machine-learned models are built. This representation allowed us to make sure, <em>by design</em>, that the output of our neural network satisfies the desired property.</p>

<p>In our case, the desired properties were monotonicity, convexity, concavity, or combinations of the above. The sets of coefficient vectors having these properties are indeed polyhedral cones. We were interested in them in the first place because we wanted to constrain continuous functions to have these properties: if the coefficients of a polynomial in the Bernstein basis satisfy these properties, so does the polynomial itself.</p>

<p>The Weyl-Minkowski theorem we used for polyhedral cones is, in fact, more general. There is a “generative” representation for any set of the form \(\{\mathbf{x} : \mathbf{A} \mathbf{x} \geq \mathbf{b} \}\), whereas in this post we explored only the case \(\mathbf{b} = \mathbf{0}\). The generator representation for the general case is a bit more complicated, but not by much, and it also integrates nicely with how PyTorch layers are built. I strongly encourage you to explore it on your own. The <code class="language-plaintext highlighter-rouge">cddlib</code> library supports the general case, so you don’t need to devise the generators yourself. You only need to understand enough to parse the library’s output.</p>
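<p>Here is a taste of the general case. By the Weyl-Minkowski theorem, a nonempty polyhedron \(\{\mathbf{x} : \mathbf{A}\mathbf{x} \geq \mathbf{b}\}\) is the set of points \(\sum_i \lambda_i \mathbf{v}_i + \sum_j \mu_j \mathbf{r}_j\) for finitely many points \(\mathbf{v}_i\) and directions \(\mathbf{r}_j\). The weights \(\lambda_i\) are non-negative and sum to one, and the weights \(\mu_j\) are non-negative.</p>

<p>Below is a minimal sketch of a layer built on this representation. It assumes the points and rays were already computed, for example by <code class="language-plaintext highlighter-rouge">cddlib</code>. It uses a softmax for the convex weights and a softplus for the conic weights. The class name <code class="language-plaintext highlighter-rouge">PolyhedronLayer</code> is hypothetical. The toy polyhedron \(\{y_1 \geq 0,\ y_2 \geq 0,\ y_1 + y_2 \geq 1\}\), with points \((1, 0), (0, 1)\) and rays \((1, 0), (0, 1)\), is an example of my own.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn

class PolyhedronLayer(nn.Module):
    """Maps an input vector to a point of conv(points) + cone(rays).

    `points` (k x d) and `rays` (m x d) are assumed to be precomputed,
    for example from the generator representation that cddlib outputs.
    """

    def __init__(self, in_dim, points, rays):
        super().__init__()
        self.register_buffer("points", points)
        self.register_buffer("rays", rays)
        self.point_logits = nn.Linear(in_dim, points.shape[0])
        self.ray_logits = nn.Linear(in_dim, rays.shape[0])

    def forward(self, x):
        # Non-negative weights that sum to 1: a convex combination of the points
        lam = torch.softmax(self.point_logits(x), dim=-1)
        # Non-negative weights: a conic combination of the rays
        mu = nn.functional.softplus(self.ray_logits(x))
        return lam @ self.points + mu @ self.rays


# Toy polyhedron: y1 and y2 non-negative, y1 + y2 at least 1 (A y - b non-negative)
A = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
b = torch.tensor([0., 0., 1.])

torch.manual_seed(0)
layer = PolyhedronLayer(
    in_dim=4,
    points=torch.tensor([[1., 0.], [0., 1.]]),
    rays=torch.eye(2),
)
y = layer(torch.randn(10, 4))
slack = y @ A.T - b  # non-negative entries mean all constraints hold
print(slack.min().item())
</code></pre></div></div>

<p>One design trade-off: softmax and softplus never produce exact zeros. As a result, points on some faces of the polyhedron are approached rather than attained, in exchange for smooth gradients.</p>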

<p>Finally, as an interesting side note, we now also understand that the output of the last layer of a ReLU network lies in a cone. This is because the ReLU activation before the last layer produces a non-negative vector, and the columns of the weight matrix of the linear layer that follows act as the generators. The bias merely shifts the apex of the cone away from the origin. I don’t know of papers that use this interpretation to do something useful, but if you do, please let me know. It sounds interesting.</p>
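<p>A small numerical illustration of this observation follows. It uses a randomly initialized network with arbitrary layer sizes of my own choosing. The output of the final <code class="language-plaintext highlighter-rouge">nn.Linear</code> equals its bias plus a non-negative combination of the columns of its weight matrix, and the ReLU activations serve as the combination weights.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn

torch.manual_seed(0)
hidden = nn.Sequential(nn.Linear(3, 8), nn.ReLU())  # produces non-negative vectors h
last = nn.Linear(8, 2)                              # columns of last.weight act as generators

x = torch.randn(5, 3)
h = hidden(x)
out = last(h)

# Rebuild the output as: bias + sum_j h_j * (j-th column of the weight matrix)
W, b = last.weight, last.bias  # W has shape (2, 8)
cone_form = b + sum(h[:, j:j + 1] * W[:, j] for j in range(W.shape[1]))

print(torch.allclose(out, cone_form, atol=1e-6))
</code></pre></div></div>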

<p>That’s it! I hope you learned something new about incorporating constraints into neural networks. I certainly have. For me, writing this short series was extremely enlightening. See you soon!</p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>Frerix, T., Nießner, M., &amp; Cremers, D. (2020). Homogeneous linear inequality constraints for neural network activations. In <em>Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops</em> (pp. 748-749). <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
<p>Minkowski, H. (1897). Allgemeine Lehrsätze über die convexen Polyeder. <em>Nachrichten von der Gesellschaft der Wissenschaften zu Göttingen, Mathematisch-Physikalische Klasse</em>, <em>1897</em>, 198-220. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:3" role="doc-endnote">
<p>Weyl, H. (1935). Elementare Theorie der konvexen Polyeder. <em>Commentarii Mathematici Helvetici</em>, <em>7</em>, 290-306. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:4" role="doc-endnote">
      <p>McMullen, P. (1970). The maximum numbers of faces of a convex polytope. <em>Mathematika</em>, <em>17</em>(2), 179-184. <a href="#fnref:4" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:5" role="doc-endnote">
<p>Avis, D., &amp; Fukuda, K. (1996). Reverse search for enumeration. <em>Discrete Applied Mathematics</em>, <em>65</em>(1-3), 21-46. <a href="#fnref:5" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:6" role="doc-endnote">
      <p>Fukuda, K., &amp; Prodon, A. (1995, July). Double description method revisited. In <em>Franco-Japanese and Franco-Chinese conference on combinatorics and computer science</em> (pp. 91-111). Berlin, Heidelberg: Springer Berlin Heidelberg. <a href="#fnref:6" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="pytorch" /><category term="machine-learning" /><category term="monotonic-regression" /><category term="bernstein" /><category term="polynomial-regression" /><category term="polyhedral-cone" /><summary type="html"><![CDATA[Enforce richer shape constraints (convexity/concavity and combinations with monotonicity) by constraining coefficient vectors to polyhedral cones. Implement a PyTorch cone layer and fit concave functions.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/polyedral_cone_layer.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/polyedral_cone_layer.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Shape restricted function models</title><link href="https://alexshtf.github.io/2024/10/14/Shape-Restricted-Models.html" rel="alternate" type="text/html" title="Shape restricted function models" /><published>2024-10-14T00:00:00+00:00</published><updated>2024-10-14T00:00:00+00:00</updated><id>https://alexshtf.github.io/2024/10/14/Shape-Restricted-Models</id><content type="html" xml:base="https://alexshtf.github.io/2024/10/14/Shape-Restricted-Models.html"><![CDATA[<h1 id="intro">Intro</h1>

<p>Occasionally, in practice, we aim to train models that represent a function of restricted shape when viewed as a function of <em>one</em> of the features. Formally, we are referring to fitting a function \(f(\mathbf{x}, z)\) that is monotone, bounded, convex, or concave in \(z\) for every \(\mathbf{x}\). The feature \(z\) is <em>special</em> in our context: the model \(f\) has a special shape as a function of \(z\). Here are some examples:</p>

<ul>
  <li>\(f(\mathbf{x}, z)\) models insurance premium given features of the policy and the insured person in \(\mathbf{x}\), and the coverage in \(z\). We would like \(f\) to be nondecreasing in \(z\) for every \(\mathbf{x}\):  larger coverage incurs a potentially larger insurance premium.</li>
  <li>\(f(\mathbf{x}, z)\) models the probability of winning an auction described by features \(\mathbf{x}\) and bid \(z\). Here, \(f\) must be bounded between 0 and 1, since it’s a probability, and nondecreasing in \(z\), since higher bids potentially mean higher chances of winning.</li>
  <li>\(f(\mathbf{x}, z)\) models the utility of an investment of \(z\) dollars in a project described by features \(\mathbf{x}\). Here it’s reasonable that \(f\) is nondecreasing and concave, to model ‘diminishing returns’.</li>
</ul>

<p>There is a vast amount of literature on learning \(f(z)\) with constraints on the shape of \(f\) for various function families, especially when \(f\) is a polynomial. In fact, there’s an entire field, polynomial optimization, whose tools apply directly to polynomial shape constraints. See <a href="https://www.youtube.com/playlist?list=PLnEqeh8YM6NbFHDUmWHvVsv7utr9jP-PM">this</a> playlist of video lectures for a great introduction, or just search the web for the term ‘polynomial optimization’. However, many of the ideas require specialized ‘acrobatics’ that are hard to implement in the commodity ML packages we all love: PyTorch and TensorFlow.</p>

<p>There is also the idea of <em>Lattice Networks</em><sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>, and a nice TensorFlow library that implements them called <a href="https://www.tensorflow.org/lattice/overview">TensorFlow Lattice</a>. They are designed for modeling functions of the form \(f(\mathbf{x}, \mathbf{z})\), where \(\mathbf{z}\) is a <em>vector</em> composed of several features for which we want to constrain the shape of \(f\). They are more general than the idea I present here, but also more expensive. This post is about a <em>scalar</em> \(z\), meaning that we have only <em>one</em> shape-constrained feature. This lets us do something interesting and specialized for this case.</p>

<p>The idea I present here is probably <em>not</em> new, even though I couldn’t find literature on it, probably because I didn’t know which buzzwords to look for. So if you know of prior work I could cite, please let me know!</p>

<p>As customary, the code is available in a <a href="https://github.com/alexshtf/alexshtf.github.io/blob/master/assets/shape_constrained_models.ipynb">notebook</a> you can deploy to Google Colab and play around with. So let’s dive in!</p>

<h1 id="bernstein-polynomials-strike-again">Bernstein polynomials strike again</h1>

<p>We already met Bernstein polynomials in our <a href="/2024/01/21/Bernstein.html">series</a> on polynomial features, so let’s briefly recap what we learned. Given a degree \(n\), we define the polynomials:</p>

\[b_{i,n}(x) = \binom{n}{i} x^i (1-x)^{n-i}.\]

<p>We can see that each \(b_{i,n}(x)\) is indeed a polynomial function of \(x\) of degree \(n\). Moreover, we learned in the series that <em>any</em> polynomial \(p(x)\) of degree at most \(n\) can be written as:</p>

\[p(x) = \sum_{i=0}^n a_{i} b_{i,n}(x).\]

<p>In other words, these polynomials are a <em>basis</em> for all polynomials of degree at most \(n\). We also learned in this series that this basis is useful for fitting functions on the unit interval \([0, 1]\) with machine-learned models, since simple regularization tricks keep the polynomials from going ‘crazy’ and ‘wiggly’. Finally, we learned that the coefficients give us direct control over the shape of \(p(x)\), and in particular:</p>

<ul>
  <li>If \(a_0 \leq a_1 \leq \dots \leq a_n\), then \(p(x)\) is nondecreasing on \([0, 1]\).</li>
  <li>If \(a_0 \geq a_1 \geq \dots \geq a_n\), then \(p(x)\) is nonincreasing on \([0, 1]\).</li>
  <li>If \(a_i \in [\alpha, \beta]\) for all \(i\), then \(p(x) \in [\alpha, \beta]\) for any \(x \in [0, 1]\).</li>
</ul>

<p>In other words, nondecreasing or nonincreasing coefficients yield a nondecreasing or nonincreasing polynomial, respectively, and imposing a bound on the coefficients imposes the same bound on the polynomial.</p>
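<p>A quick numerical sanity check of these properties is sketched below. The helper <code class="language-plaintext highlighter-rouge">bernstein_poly</code> is hypothetical and is not the implementation developed later in this post. It evaluates the basis directly from the definition in double precision. The coefficients are arbitrary nondecreasing values in \([-1, 2]\).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import torch

def bernstein_poly(coefs, z):
    """Evaluates p(z) = sum_i coefs[i] * b_{i,n}(z), with n = len(coefs) - 1."""
    n = len(coefs) - 1
    i = torch.arange(n + 1, dtype=z.dtype)
    binom = torch.tensor([math.comb(n, k) for k in range(n + 1)], dtype=z.dtype)
    zc = z.reshape(-1, 1)  # grid points in rows, basis indices in columns
    basis = binom * zc ** i * (1 - zc) ** (n - i)
    return basis @ coefs

# Nondecreasing coefficients, all inside [-1, 2]
coefs = torch.tensor([-1.0, -0.5, -0.5, 0.3, 1.2, 2.0], dtype=torch.float64)
z = torch.linspace(0, 1, 1001, dtype=torch.float64)
p = bernstein_poly(coefs, z)

print(p.diff().min().item())           # smallest increment along the grid: positive
print(p.min().item(), p.max().item())  # stays inside [-1, 2]
</code></pre></div></div>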

<p>So, assuming \(z \in [0, 1]\), the basic idea is simple. Choose a polynomial degree \(n\), and feed \(\mathbf{x}\) to an <em>arbitrary</em> model that produces a coefficient vector \(\mathbf{a} = (a_0, \dots, a_n)\) having the desired shape properties. Then let the model’s output be the corresponding polynomial in the Bernstein basis. The basic flow is illustrated below:</p>

<p><img src="https://alexshtf.github.io/assets/increasing_function_model.png" alt="increasing_function_model" /></p>

<p>Observe that we don’t really care what the model consuming \(\mathbf{x}\) looks like. For all we care, \(\mathbf{x}\) can be free-form text describing an insurance policy, and the model consuming \(\mathbf{x}\) can be our super-duper state-of-the-art transformer that understands insurance policies and produces an embedding vector. But the embedding vector is not arbitrary: it’s a coefficient vector for Bernstein polynomials satisfying a desired shape property. Thus, in this example, the model would have to be fine-tuned for the task of producing the appropriate Bernstein coefficients.</p>

<p>The basic idea of learning a model to predict the coefficient vector of a function is <strong>not</strong> new. To the best of my knowledge, it dates back to the 1993 paper of Hastie and Tibshirani<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup>, and more papers applying the idea have appeared over the years<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup><sup id="fnref:4" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">4</a></sup><sup id="fnref:5" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">5</a></sup>. That’s why this is a blog post rather than a paper. It is one of those posts where I want to understand something by implementing it, and share my understanding and learning experience with the readers.</p>

<p>Before developing the basic idea into a more concrete framework, let’s recall one more interesting fact we learned in the series about Bernstein polynomials. The Bernstein coefficients control the polynomial locally, in the vicinity of the points on a <em>grid</em>, or a <em>lattice</em>: the coefficient \(a_i\) mainly affects \(p\) near the grid point \(i/n\), where \(b_{i,n}\) attains its maximum. In this sense, we can think of this basic idea as an enhancement of one-dimensional lattice networks.</p>

<h1 id="implementing-the-framework-in-pytorch">Implementing the framework in PyTorch</h1>

<p>To implement this idea we need to take care of two details: what happens if \(z\) is <em>not</em> in \([0, 1]\), and how to generate Bernstein coefficients satisfying our desired properties. Then, we shall implement everything in PyTorch.</p>

<p>First, we discuss what happens if \(z\) is <em>not</em> in \([0, 1]\). As machine learning practitioners we have a pretty standard solution: feature scaling. For example, if \(z\) is known to be bounded, we can use simple min-max scaling. For a potentially unbounded but non-negative feature, such as duration or money, we could scale using \(\tanh\), \(\arctan\), or an algebraic function such as:</p>

\[\phi_a(z) = \frac{z}{a + z}\]

<p>The choice of the scaling function is where our domain knowledge about \(z\) is useful, and this is the “feature engineering” part of our idea. Because feature scaling is typically a part of the data preparation components of a machine learning pipeline, rather than the model, we assume here that our model takes an already scaled \(z\).</p>
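<p>The scaling options above can be sketched as simple tensor functions. This is a minimal sketch: the function names, the coverage amounts, and the scale \(a = 50{,}000\) are all hypothetical. For the algebraic option I use the increasing form \(z / (a + z)\). An increasing scaling keeps a polynomial that is nondecreasing in the scaled feature nondecreasing in \(z\), whereas a decreasing map such as \(a / (a + z)\) would flip the direction of monotonicity.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

def min_max_scale(z, z_min, z_max):
    # for a feature known to lie in [z_min, z_max]
    return (z - z_min) / (z_max - z_min)

def tanh_scale(z, a):
    # for a non-negative, possibly unbounded feature; a controls where tanh saturates
    return torch.tanh(z / a)

def algebraic_scale(z, a):
    # increasing algebraic scaling: maps [0, inf) into [0, 1), equals 1/2 at z = a
    return z / (a + z)

# Hypothetical coverage amounts, with a hypothetical scale parameter a = 50,000
coverage = torch.tensor([0., 1e4, 5e4, 1e5, 1e6])
print(min_max_scale(coverage, 0., 1e6))
print(tanh_scale(coverage, a=5e4))
print(algebraic_scale(coverage, a=5e4))
</code></pre></div></div>

<p>Note that an increasing nonlinear scaling preserves monotonicity in \(z\), but in general it does not preserve convexity or concavity in the original, unscaled \(z\).</p>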

<p>Now, let’s discuss how to ensure that the ‘embedding vector’ \(\mathbf{a}\) that our model produces has the right properties (monotonicity / boundedness). This can be achieved by stacking an additional ‘coefficient transform’ layer on top of an existing model. For example, if the last layer of a given model produces a vector \(\mathbf{u} = (u_0, \dots, u_n)\), our ‘coefficient transform’ layer can produce a nondecreasing \(\mathbf{a}\) using \(\mathrm{ReLU}\):</p>

\[a_i = u_0 + \sum_{j=1}^i \mathrm{ReLU}(u_j),\]

<p>or using \(\mathrm{SoftPlus}\):</p>

\[a_i = u_0 + \sum_{j=1}^i \mathrm{SoftPlus}(u_j).\]

<p>Below is a \(\mathrm{SoftPlus}\)-based implementation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">NondecreasingCoefTransform</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">u</span><span class="p">):</span>
        <span class="c1"># We assume that `u` has mini-batch dimensions,
</span>        <span class="c1"># and the 'coefficient' dimension is the last one.
</span>        <span class="n">u_head</span> <span class="o">=</span> <span class="n">u</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">u_tail_softplus</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">u</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:])</span>
        <span class="n">head_tail</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">u_head</span><span class="p">,</span> <span class="n">u_tail_softplus</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">head_tail</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<p>Let’s try it out:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">u</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">5.</span><span class="p">,</span> <span class="mf">3.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">3.</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="n">NondecreasingCoefTransform</span><span class="p">()(</span><span class="n">u</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([-5.0000, -1.9514,  0.1755,  0.4888,  3.5374])
</code></pre></div></div>

<p>Now we can stack a <code class="language-plaintext highlighter-rouge">NondecreasingCoefTransform</code> on top of an existing network, and obtain nondecreasing coefficients.</p>
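<p>The same recipe extends to the auction example, where the coefficients must be nondecreasing <em>and</em> lie in \([0, 1]\). Below is a minimal sketch of one possible transform. It is a hypothetical variant of mine, not part of the original code. A softmax produces positive increments that sum to one, and their cumulative sum is rescaled to the target interval. Combined with the bound property of Bernstein coefficients, this yields a polynomial that is nondecreasing and bounded on \([0, 1]\).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn

class BoundedNondecreasingCoefTransform(nn.Module):
    """Nondecreasing coefficients that lie in the interval (low, high]."""

    def __init__(self, low, high):
        super().__init__()
        self.low = low
        self.high = high

    def forward(self, u):
        # softmax yields positive weights summing to 1 along the last dimension,
        # so their cumulative sum increases and ends at 1 (up to rounding)
        steps = torch.softmax(u, dim=-1)
        return self.low + (self.high - self.low) * torch.cumsum(steps, dim=-1)

u = torch.tensor([-5., 3., 2., -1., 3.])
a = BoundedNondecreasingCoefTransform(low=0., high=1.)(u)
print(a)
</code></pre></div></div>

<p>One design trade-off: softmax weights are strictly positive, so the first coefficient lies strictly above <code class="language-plaintext highlighter-rouge">low</code>. The polynomial can therefore approach, but never exactly attain, the lower bound at \(z = 0\).</p>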

<p>Next, let’s implement the rest of the idea in PyTorch. First, we need to compute the Bernstein basis with vectorized PyTorch functions that run well on both CPU and GPU. For simplicity, even though it may not be the ‘best’ way to do it, we shall compute the basis <em>by definition</em>.</p>

<p>It turns out that PyTorch does not have a built-in function to compute the binomial coefficient \(\binom{n}{k}\), so let’s implement one. Implementing it directly may cause overflow, since the binomial coefficient is defined in terms of factorials. Moreover, we would like a vectorized implementation that can take many values of \(k\) at once. PyTorch does have the right tools, but <em>in logarithmic space</em>: the <code class="language-plaintext highlighter-rouge">torch.lgamma</code> function computes the logarithm of the Gamma function. Recall that the Gamma function generalizes the factorial, since for a non-negative integer \(n\) we have:</p>

\[\Gamma(n + 1) = n!\]

<p>Therefore,</p>

\[\ln\binom{n}{k} = \ln\left( \frac{n!}{k!\,(n-k)!} \right) = \ln\Gamma(n+1) - \ln\Gamma(k+1) - \ln\Gamma(n - k + 1)\]

<p>So the code for the binomial coefficient in log-space is:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>

<span class="k">def</span> <span class="nf">log_binom_coef</span><span class="p">(</span><span class="n">n</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
  <span class="k">return</span> <span class="p">(</span>
      <span class="n">torch</span><span class="p">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
      <span class="o">-</span> <span class="n">torch</span><span class="p">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">k</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> 
      <span class="o">-</span> <span class="n">torch</span><span class="p">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">n</span> <span class="o">-</span> <span class="n">k</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
  <span class="p">)</span>
</code></pre></div></div>

<p>Let’s see that it works by printing \(\binom{5}{i}\) for \(i = 0, \dots, 5\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">6</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">log_binom_coef</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">).</span><span class="n">exp</span><span class="p">())</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([ 1.,  5., 10., 10.,  5.,  1.])
</code></pre></div></div>
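
<p>To see why the log-space form matters, here is a small sketch (not part of the original notebook) that compares <code class="language-plaintext highlighter-rouge">log_binom_coef</code> with Python’s exact integer <code class="language-plaintext highlighter-rouge">math.comb</code> for \(n = 1000\) and \(i = 500\). The function is repeated so the snippet runs on its own, and the values are arbitrary. The logarithm \(\ln\binom{1000}{500} \approx 689.5\) is easy to represent, but the coefficient itself is roughly \(10^{299}\), far beyond the largest float32 value of about \(3.4 \times 10^{38}\), so exponentiating it in float32 gives infinity:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import torch

def log_binom_coef(n: torch.Tensor, k: torch.Tensor):
  return (
      torch.lgamma(n + 1)
      - torch.lgamma(k + 1)
      - torch.lgamma(n - k + 1)
  )

n, k = 1000, 500

# exact reference, using Python's arbitrary-precision integers
exact_log = math.log(math.comb(n, k))

# the log-space computation in float64 stays finite and accurate
log64 = log_binom_coef(torch.tensor(n, dtype=torch.float64),
                       torch.tensor(k, dtype=torch.float64))

# the same quantity in float32, once exponentiated, overflows to inf
log32 = log_binom_coef(torch.tensor(float(n)), torch.tensor(float(k)))
direct32 = torch.exp(log32)

print(exact_log, log64.item(), direct32.item())
</code></pre></div></div>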

<p>Looks just right. Now we can implement the Bernstein basis naively, directly from its definition \(b_{i,n}(z) = \binom{n}{i} z^i (1 - z)^{n - i}\) for \(i = 0, \dots, n\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">bernstein_basis</span><span class="p">(</span><span class="n">degree</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
  <span class="s">"""
  Computes a matrix containing the Bernstein basis of a given degree, where
  each row corresponds to an entry in the input tensor `z`.
  """</span>

  <span class="c1"># entries of `z` in rows, and basis indices in columns  
</span>  <span class="n">z</span> <span class="o">=</span> <span class="n">z</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> 
  <span class="n">ks</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">degree</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">z</span><span class="p">.</span><span class="n">device</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

  <span class="c1"># degree in a tensor to call log_binom_coef
</span>  <span class="n">degree_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">z</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>

  <span class="c1"># now we compute the Bernstein basis by definition
</span>  <span class="n">binom_coef</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_binom_coef</span><span class="p">(</span><span class="n">degree_tensor</span><span class="p">,</span> <span class="n">ks</span><span class="p">))</span>
  <span class="k">return</span> <span class="n">binom_coef</span> <span class="o">*</span> <span class="p">(</span><span class="n">z</span> <span class="o">**</span> <span class="n">ks</span><span class="p">)</span> <span class="o">*</span> <span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">z</span><span class="p">)</span> <span class="o">**</span> <span class="p">(</span><span class="n">degree_tensor</span> <span class="o">-</span> <span class="n">ks</span><span class="p">))</span>
</code></pre></div></div>
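
<p>A quick sanity check of <code class="language-plaintext highlighter-rouge">bernstein_basis</code>, written as a small self-contained sketch (both helpers are repeated): on \([0, 1]\) every basis function is non-negative, and for every \(z\) the basis functions sum to one, since by the binomial theorem \(\sum_{i=0}^n \binom{n}{i} z^i (1-z)^{n-i} = (z + (1 - z))^n = 1\). The degree of 10 and the grid of 101 points are arbitrary choices:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

def log_binom_coef(n: torch.Tensor, k: torch.Tensor):
  return torch.lgamma(n + 1) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1)

def bernstein_basis(degree: int, z: torch.Tensor):
  z = z.view(-1, 1)
  ks = torch.arange(degree + 1, device=z.device).view(1, -1)
  degree_tensor = torch.as_tensor(degree, device=z.device)
  binom_coef = torch.exp(log_binom_coef(degree_tensor, ks))
  return binom_coef * (z ** ks) * ((1 - z) ** (degree_tensor - ks))

zs = torch.linspace(0, 1, 101)
basis = bernstein_basis(10, zs)   # one row per z, one column per basis function

row_sums = basis.sum(dim=-1)
print(basis.shape)                # torch.Size([101, 11])
print(basis.min().item())         # smallest entry, never negative
print(row_sums.min().item(), row_sums.max().item())   # both close to 1
</code></pre></div></div>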

<p>As stated above, this is not the most numerically ‘right’ way to work with the Bernstein basis, and it would be wiser to use the well-known <a href="https://en.wikipedia.org/wiki/De_Casteljau%27s_algorithm">De Casteljau’s algorithm</a>, which is both efficient and numerically stable. In production-quality code, that is what we should do, perhaps even with a custom CUDA kernel to make it efficient on the GPU. But I chose not to add complexity by introducing yet another algorithm, to keep this post as straightforward as possible.</p>
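
<p>For the curious, here is a minimal sketch of what De Casteljau’s algorithm could look like in PyTorch; it is an illustration, not the code used in this post. It evaluates a polynomial in Bernstein form, \(\sum_{i=0}^n a_i \binom{n}{i} z^i (1-z)^{n-i}\), directly from its coefficients, without binomial coefficients or large powers: each round replaces every pair of neighboring coefficients \(a_i, a_{i+1}\) by \((1 - z) a_i + z a_{i+1}\), and after \(n\) rounds a single number, the value of the polynomial, remains. The name <code class="language-plaintext highlighter-rouge">de_casteljau</code> and the batching convention (coefficients in the last dimension, one \(z\) per row) are assumptions of this sketch:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

def de_casteljau(coefs: torch.Tensor, z: torch.Tensor):
  """
  Evaluates polynomials in Bernstein form. `coefs` has shape (..., n + 1),
  with the coefficients in the last dimension, and `z` has shape (...).
  Returns a tensor of shape (...).
  """
  z = z.unsqueeze(-1)  # align each z with its row of coefficients
  b = coefs
  for _ in range(coefs.shape[-1] - 1):
    # replace each pair of neighbors by their convex combination;
    # every round shortens the last dimension by one
    b = (1 - z) * b[..., :-1] + z * b[..., 1:]
  return b[..., 0]

coefs = torch.tensor([[0., 1., 3.],
                      [2., 2., 2.]])
z = torch.tensor([0.5, 0.25])
print(de_casteljau(coefs, z))  # tensor([1.2500, 2.0000])
</code></pre></div></div>

<p>For \(z \in [0, 1]\) every round is a convex combination, so intermediate values never leave the range spanned by the coefficients, which is where the algorithm’s numerical stability comes from.</p>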

<p>Now that our ingredients are in place, let’s write a short PyTorch module that implements the idea in the nice diagram we saw above:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">BernsteinPolynomialModel</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_model</span><span class="p">,</span> <span class="n">coef_transformer</span><span class="p">):</span>
    <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
    <span class="c1"># x is mapped to raw coefficients, which are then constrained
</span>    <span class="bp">self</span><span class="p">.</span><span class="n">coef_model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
        <span class="n">x_model</span><span class="p">,</span>
        <span class="n">coef_transformer</span>
    <span class="p">)</span>
  
  <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
    <span class="n">coefs</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">coef_model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="c1"># a polynomial of degree n has n + 1 Bernstein coefficients
</span>    <span class="n">degree</span> <span class="o">=</span> <span class="n">coefs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="n">basis</span> <span class="o">=</span> <span class="n">bernstein_basis</span><span class="p">(</span><span class="n">degree</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">coefs</span> <span class="o">*</span> <span class="n">basis</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
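
<p>Here is a small usage sketch of the module, with the class and helpers repeated so it runs on its own. As hypothetical stand-ins, a single <code class="language-plaintext highlighter-rouge">nn.Linear(3, 6)</code> layer plays the role of <code class="language-plaintext highlighter-rouge">x_model</code>, producing 6 coefficients and hence a degree-5 polynomial, and <code class="language-plaintext highlighter-rouge">nn.Identity()</code> plays the role of <code class="language-plaintext highlighter-rouge">coef_transformer</code>. The sketch also checks a standard property of Bernstein polynomials: the polynomial equals its first coefficient at \(z = 0\) and its last coefficient at \(z = 1\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn

def log_binom_coef(n, k):
  return torch.lgamma(n + 1) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1)

def bernstein_basis(degree: int, z: torch.Tensor):
  z = z.view(-1, 1)
  ks = torch.arange(degree + 1, device=z.device).view(1, -1)
  degree_tensor = torch.as_tensor(degree, device=z.device)
  binom_coef = torch.exp(log_binom_coef(degree_tensor, ks))
  return binom_coef * (z ** ks) * ((1 - z) ** (degree_tensor - ks))

class BernsteinPolynomialModel(nn.Module):
  def __init__(self, x_model, coef_transformer):
    super().__init__()
    self.coef_model = nn.Sequential(x_model, coef_transformer)

  def forward(self, x, z):
    coefs = self.coef_model(x)
    degree = coefs.shape[-1] - 1   # n + 1 coefficients means degree n
    basis = bernstein_basis(degree, z)
    return torch.sum(coefs * basis, dim=-1)

torch.manual_seed(0)
# hypothetical stand-ins: 3 input features, 6 coefficients (degree 5), no constraint
model = BernsteinPolynomialModel(nn.Linear(3, 6), nn.Identity())

x = torch.randn(4, 3)
coefs = model.coef_model(x)
at_zero = model(x, torch.zeros(4))
at_one = model(x, torch.ones(4))

print(at_zero.shape)   # torch.Size([4]): one prediction per row
print(torch.allclose(at_zero, coefs[:, 0]), torch.allclose(at_one, coefs[:, -1]))  # True True
</code></pre></div></div>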

<h2 id="coefficient-transform-components">Coefficient transform components</h2>

<p>We already saw a simple transform, based on the ReLU function, that converts a vector into a vector with non-decreasing components. We can obtain non-increasing components in a similar manner, using the negative of the ReLU function:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">NonIncreasingCoefTransform</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">u</span><span class="p">):</span>
        <span class="c1"># We assume that `u` has mini-batch dimensions,
</span>        <span class="c1"># and the 'coefficient' dimension is the last one.
</span>        <span class="n">u_head</span> <span class="o">=</span> <span class="n">u</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">u_tail_relu</span> <span class="o">=</span> <span class="o">-</span><span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="n">u</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:])</span>
        <span class="c1"># accumulate the non-positive increments starting from the first entry
</span>        <span class="n">u_tail</span> <span class="o">=</span> <span class="n">u_head</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">u_tail_relu</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">u_head</span><span class="p">,</span> <span class="n">u_tail</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<p>Let’s test it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">NonIncreasingCoefTransform</span><span class="p">()(</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mf">3.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">3.</span><span class="p">])))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([ -5.,  -8., -10., -10., -13.])
</code></pre></div></div>
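
<p>Since the transform is meant for mini-batches, here is a small sketch that applies it to a random batch (the class is repeated so the snippet runs on its own) and checks, with <code class="language-plaintext highlighter-rouge">torch.diff</code>, that every row is non-increasing. This version offsets the running sums by the first entry and concatenates along the last dimension, <code class="language-plaintext highlighter-rouge">dim=-1</code>, which keeps each row intact; the batch of 8 rows with 5 coefficients is arbitrary:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn

class NonIncreasingCoefTransform(nn.Module):
  def forward(self, u):
    u_head = u[..., 0:1]
    # non-positive increments, accumulated starting from the first entry
    u_tail = u_head + torch.cumsum(-nn.functional.relu(u[..., 1:]), dim=-1)
    return torch.cat([u_head, u_tail], dim=-1)

torch.manual_seed(0)
u = torch.randn(8, 5)   # a mini-batch of 8 raw coefficient vectors
coefs = NonIncreasingCoefTransform()(u)

steps = torch.diff(coefs, dim=-1)        # differences between consecutive coefficients
print(coefs.shape)                       # torch.Size([8, 5])
print(torch.le(steps, 0).all().item())   # True: no row ever increases
</code></pre></div></div>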

<p>It appears to do what we wanted: it produces a non-increasing vector. What if we’re modeling a CDF? Then we can add a <code class="language-plaintext highlighter-rouge">Sigmoid</code> layer on top of a <code class="language-plaintext highlighter-rouge">NonDecreasingCoefTransform</code>. It maps our non-decreasing coefficients, whose values are arbitrary real numbers, to non-decreasing coefficients in \([0, 1]\). Namely, we can use:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="n">NonDecreasingCoefTransform</span><span class="p">(),</span>
    <span class="n">nn</span><span class="p">.</span><span class="n">Sigmoid</span><span class="p">()</span>
<span class="p">)</span>
</code></pre></div></div>
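
<p>Why does this yield a valid function with values in \([0, 1]\)? The Bernstein basis functions are non-negative and sum to one, so the polynomial is a weighted average of its coefficients, and coefficients in \([0, 1]\) keep it in \([0, 1]\). Moreover, its derivative is a non-negative combination of the differences \(a_{i+1} - a_i\), so non-decreasing coefficients give a non-decreasing polynomial. The sketch below checks both numerically. Since <code class="language-plaintext highlighter-rouge">NonDecreasingCoefTransform</code> was defined earlier in the post and is not repeated here, it uses a hypothetical stand-in, <code class="language-plaintext highlighter-rouge">NonDecreasingSketch</code>, which mirrors <code class="language-plaintext highlighter-rouge">NonIncreasingCoefTransform</code> with non-negative increments; it is an assumption, not necessarily the post’s exact implementation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn

def log_binom_coef(n, k):
  return torch.lgamma(n + 1) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1)

def bernstein_basis(degree: int, z: torch.Tensor):
  z = z.view(-1, 1)
  ks = torch.arange(degree + 1, device=z.device).view(1, -1)
  degree_tensor = torch.as_tensor(degree, device=z.device)
  binom_coef = torch.exp(log_binom_coef(degree_tensor, ks))
  return binom_coef * (z ** ks) * ((1 - z) ** (degree_tensor - ks))

class NonDecreasingSketch(nn.Module):
  # hypothetical stand-in for the post's NonDecreasingCoefTransform:
  # the first entry, followed by running sums of non-negative increments
  def forward(self, u):
    u_head = u[..., 0:1]
    u_tail = u_head + torch.cumsum(nn.functional.relu(u[..., 1:]), dim=-1)
    return torch.cat([u_head, u_tail], dim=-1)

cdf_coef_transform = nn.Sequential(NonDecreasingSketch(), nn.Sigmoid())

torch.manual_seed(0)
coefs = cdf_coef_transform(3 * torch.randn(9))   # 9 coefficients, so degree 8

zs = torch.linspace(0, 1, 201)
values = bernstein_basis(8, zs) @ coefs          # the polynomial at each z

print(values.min().item(), values.max().item())  # within [0, 1], up to round-off
print(torch.diff(values).min().item())           # smallest step, not negative beyond round-off
</code></pre></div></div>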

<p>An interesting case is a CDF of a distribution whose support is <em>known</em> to be \([0, 1]\). Then we can model it directly with Bernstein polynomials whose coefficient vector \(\mathbf{a}\) satisfies:</p>

\[a_0 = 0 \leq a_1 \leq \dots \leq a_n = 1.\]

<p>To that end, we can use the SoftMax function with a cumulative sum. Assuming that \(\mathbf{u} \in \mathbb{R}^n\), we can define:</p>

\[a_i = \frac{\sum_{j=1}^i \exp(u_j)}{\sum_{j=1}^n \exp(u_j)}, \qquad i = 0, \dots, n\]

<p>Consequently, the corresponding layer is:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">CDFCoefTransform</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">u</span><span class="p">):</span>
    <span class="n">zero</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">u</span><span class="p">[...,</span> <span class="p">:</span><span class="mi">1</span><span class="p">])</span>
    <span class="n">cum_softmax</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">cdf_coefs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">zero</span><span class="p">,</span> <span class="n">cum_softmax</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">cdf_coefs</span>
</code></pre></div></div>

<p>Let’s try it out:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">CDFCoefTransform</span><span class="p">()(</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mf">3.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">3.</span><span class="p">])))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([0.0000e+00, 1.4056e-04, 4.1916e-01, 5.7331e-01, 5.8098e-01, 1.0000e+00])
</code></pre></div></div>
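
<p>Putting the pieces together, here is a sketch that evaluates the resulting degree-5 polynomial on a grid of \(z\) values, with <code class="language-plaintext highlighter-rouge">bernstein_basis</code> and <code class="language-plaintext highlighter-rouge">CDFCoefTransform</code> repeated so it runs on its own. Because a Bernstein polynomial equals its first coefficient at \(z = 0\) and its last one at \(z = 1\), the constraints \(a_0 = 0\) and \(a_n = 1\) make the curve start at 0 and end at 1, as a CDF on \([0, 1]\) should; the grid of 11 points is arbitrary:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn

def log_binom_coef(n, k):
  return torch.lgamma(n + 1) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1)

def bernstein_basis(degree: int, z: torch.Tensor):
  z = z.view(-1, 1)
  ks = torch.arange(degree + 1, device=z.device).view(1, -1)
  degree_tensor = torch.as_tensor(degree, device=z.device)
  binom_coef = torch.exp(log_binom_coef(degree_tensor, ks))
  return binom_coef * (z ** ks) * ((1 - z) ** (degree_tensor - ks))

class CDFCoefTransform(nn.Module):
  def forward(self, u):
    zero = torch.zeros_like(u[..., :1])
    cum_softmax = torch.cumsum(nn.functional.softmax(u, dim=-1), dim=-1)
    return torch.cat([zero, cum_softmax], dim=-1)

a = CDFCoefTransform()(torch.tensor([-5, 3., 2., -1., 3.]))  # 6 coefficients, degree 5

zs = torch.linspace(0, 1, 11)
cdf = bernstein_basis(5, zs) @ a   # the CDF evaluated at each z

print(cdf[0].item(), cdf[-1].item())   # 0.0 at z = 0, and 1 (up to round-off) at z = 1
print(torch.diff(cdf).min().item())    # smallest step between neighboring grid points
</code></pre></div></div>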

<p>Appears to do what we desire - a non-decreasing vector, going from 0 to 1. Now let’s try to use our components.</p>

<h1 id="example---learning-an-increasing-function">Example - learning an increasing function</h1>

<p>At first, I wanted to demonstrate it on an application from a domain I know: learning the CDF of auction bids in online advertising. But the datasets, such as the IPinYou dataset, are too large to handle quickly enough for a blog post. So instead, we’ll fit a synthetic function \(f(\mathbf{x}, z)\), implemented in NumPy to make plotting and inspection straightforward. When we use it to generate a dataset, we shall convert the NumPy arrays into PyTorch tensors.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="k">def</span> <span class="nf">relu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
  <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">softshrink</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="mf">0.3</span><span class="p">):</span>
  <span class="k">return</span> <span class="n">relu</span><span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">a</span><span class="p">)</span> <span class="o">-</span> <span class="n">relu</span><span class="p">(</span><span class="o">-</span><span class="n">x</span> <span class="o">-</span> <span class="n">a</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">sgn_square</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
  <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">hairy_increasing_func</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
  <span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">x3</span> <span class="o">=</span> <span class="n">x</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">x</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">x</span><span class="p">[...,</span> <span class="mi">2</span><span class="p">]</span>
  <span class="k">return</span> <span class="p">(</span><span class="n">relu</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">x1</span> <span class="o">-</span> <span class="n">x2</span> <span class="o">+</span> <span class="n">x3</span><span class="p">))</span> <span class="o">*</span> <span class="n">sgn_square</span><span class="p">(</span><span class="n">softshrink</span><span class="p">(</span><span class="n">z</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">x1</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
          <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">x2</span> <span class="o">+</span> <span class="n">x3</span><span class="p">))</span> <span class="o">*</span> <span class="n">sgn_square</span><span class="p">(</span><span class="n">softshrink</span><span class="p">(</span><span class="n">z</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">x2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
          <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">x1</span> <span class="o">-</span> <span class="n">x2</span><span class="p">))</span> <span class="o">*</span> <span class="n">sgn_square</span><span class="p">(</span><span class="n">softshrink</span><span class="p">(</span><span class="n">z</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">x3</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
          <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">x1</span> <span class="o">+</span> <span class="n">x2</span> <span class="o">+</span> <span class="n">x3</span><span class="p">))</span>
</code></pre></div></div>
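
<p>Before fitting, it is worth confirming numerically that the function is indeed non-decreasing in \(z\): each soft-shrink term is multiplied by a non-negative factor, either \(\mathrm{ReLU}(\cdot)\) or \(1 + \cos(\cdot)\), and the signed square \(t \mapsto t\lvert t\rvert\) of a non-decreasing function is non-decreasing. The sketch below (with the definitions repeated so it runs on its own) evaluates the function on a grid of \(z\) values for a few random \(\mathbf{x}\) and checks that consecutive differences are never negative; the grid size and number of samples are arbitrary:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def relu(x):
  return np.maximum(np.zeros_like(x), x)

def softshrink(x, a=0.3):
  return relu(x - a) - relu(-x - a)

def sgn_square(x):
  return x * np.abs(x)

def hairy_increasing_func(x, z):
  x1, x2, x3 = x[..., 0], x[..., 1], x[..., 2]
  return (relu(np.cos(x1 - x2 + x3)) * sgn_square(softshrink(z - np.sin(x1) ** 2))
          + (1 + np.cos(x2 + x3)) * sgn_square(softshrink(z - np.sin(x2) ** 2))
          + (1 + np.cos(x1 - x2)) * sgn_square(softshrink(z - np.sin(x3) ** 2))
          + np.cos(x1 + x2 + x3))

rng = np.random.default_rng(0)
xs = rng.standard_normal((20, 1, 3))    # 20 random feature vectors
zs = np.linspace(0, 1, 1000)[None, :]   # shape (1, 1000), broadcast against the rows of xs
values = hairy_increasing_func(xs, zs)  # shape (20, 1000): one curve per feature vector

steps = np.diff(values, axis=-1)
print(values.shape, steps.min())
</code></pre></div></div>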

<p>It indeed looks a bit ‘hairy’, so let’s inspect a few examples:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="n">zs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">zs</span><span class="p">,</span> <span class="n">hairy_increasing_func</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]),</span> <span class="n">zs</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">'function 1'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">zs</span><span class="p">,</span> <span class="n">hairy_increasing_func</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">]),</span> <span class="n">zs</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">'function 2'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">zs</span><span class="p">,</span> <span class="n">hairy_increasing_func</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">1.5</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]),</span> <span class="n">zs</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">'function 3'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/hairy_increasing_functions.png" alt="hairy_increasing_functions" /></p>

<p>The function uses signed squares of the ‘soft-shrink’ function, which generate ‘flat’ plateaus. The centers of these plateaus and the scales of the soft-shrink terms are trigonometric functions of the components of \(\mathbf{x}\). The signed square of the soft-shrink function has a continuous first derivative but a discontinuous second derivative, and this makes fitting a bit challenging, even with a small number of features. But it’s possible with polynomials of a high enough degree. As we saw in the polynomial features series, we are <em>not</em> afraid of fitting high-degree polynomials.</p>

<p>Using this function, we can generate a PyTorch dataset. The function below generates a dataset of the specified size and moves it to the default CUDA GPU if one is available, to keep our fitting experiments simple and fast when we do have a GPU:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">generate_dataset</span><span class="p">(</span><span class="n">n_rows</span><span class="p">,</span> <span class="n">noise</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
  <span class="n">xs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n_rows</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
  <span class="n">zs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="n">n_rows</span><span class="p">)</span>
  <span class="n">labels</span> <span class="o">=</span> <span class="n">hairy_increasing_func</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">zs</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n_rows</span><span class="p">)</span> <span class="o">*</span> <span class="n">noise</span>

  <span class="n">xs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">xs</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
  <span class="n">zs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">zs</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
  <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">labels</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
  <span class="k">if</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">():</span>
    <span class="n">xs</span> <span class="o">=</span> <span class="n">xs</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>
    <span class="n">zs</span> <span class="o">=</span> <span class="n">zs</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>

  <span class="k">return</span> <span class="n">xs</span><span class="p">,</span> <span class="n">zs</span><span class="p">,</span> <span class="n">labels</span>
</code></pre></div></div>
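
<p>A quick sanity check of the generated data, as a self-contained sketch with the definitions repeated. Seeding NumPy’s global generator with <code class="language-plaintext highlighter-rouge">np.random.seed</code> makes the run repeatable, since <code class="language-plaintext highlighter-rouge">generate_dataset</code> draws from it; the seed and the 10,000 rows are arbitrary. The residual between the labels and the noiseless function should have a standard deviation close to <code class="language-plaintext highlighter-rouge">noise</code>:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
import torch

# definitions repeated from above so that this snippet runs on its own
def relu(x):
  return np.maximum(np.zeros_like(x), x)

def softshrink(x, a=0.3):
  return relu(x - a) - relu(-x - a)

def sgn_square(x):
  return x * np.abs(x)

def hairy_increasing_func(x, z):
  x1, x2, x3 = x[..., 0], x[..., 1], x[..., 2]
  return (relu(np.cos(x1 - x2 + x3)) * sgn_square(softshrink(z - np.sin(x1) ** 2))
          + (1 + np.cos(x2 + x3)) * sgn_square(softshrink(z - np.sin(x2) ** 2))
          + (1 + np.cos(x1 - x2)) * sgn_square(softshrink(z - np.sin(x3) ** 2))
          + np.cos(x1 + x2 + x3))

def generate_dataset(n_rows, noise=0.1):
  xs = np.random.randn(n_rows, 3)
  zs = np.random.rand(n_rows)
  labels = hairy_increasing_func(xs, zs) + np.random.randn(n_rows) * noise

  xs = torch.as_tensor(xs).to(dtype=torch.float32)
  zs = torch.as_tensor(zs).to(dtype=torch.float32)
  labels = torch.as_tensor(labels).to(dtype=torch.float32)
  if torch.cuda.is_available():
    xs, zs, labels = xs.cuda(), zs.cuda(), labels.cuda()
  return xs, zs, labels

np.random.seed(0)  # generate_dataset uses NumPy's global random generator
xs, zs, labels = generate_dataset(10_000)
print(xs.shape, zs.shape, labels.shape, xs.dtype, xs.device)

# labels minus the noiseless function: should look like zero-mean noise
clean = hairy_increasing_func(xs.cpu().numpy(), zs.cpu().numpy())
residual = labels.cpu().numpy() - clean
print(residual.std())  # close to the default noise level of 0.1
</code></pre></div></div>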

<p>Our next ingredient is a function that builds a PyTorch model. We will compare a monotonic model, based on the <code class="language-plaintext highlighter-rouge">BernsteinPolynomialModel</code> class we just implemented, to a regular fully-connected \(\mathrm{ReLU}\) network. So here is a function that creates a model from a list of layer dimensions and supports both cases:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_model</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">,</span> <span class="n">monotone</span><span class="o">=</span><span class="bp">True</span><span class="p">):</span>
  <span class="c1"># create a fully connected ReLU network
</span>  <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
      <span class="n">layer</span>
      <span class="k">for</span> <span class="n">in_dim</span><span class="p">,</span> <span class="n">out_dim</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_dims</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
      <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="p">[</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_dim</span><span class="p">,</span> <span class="n">out_dim</span><span class="p">),</span> <span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">()]</span>
  <span class="p">]</span>

  <span class="k">if</span> <span class="n">monotone</span><span class="p">:</span>
    <span class="c1"># define a model for x - a ReLU network whose last layer is linear
</span>    <span class="n">x_model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>

    <span class="c1"># construct a network for predicting non-decreasing functions.
</span>    <span class="c1"># the number of Bernstein coefficients is the output dimension
</span>    <span class="c1"># of the last layer, so the polynomial degree is one less.
</span>    <span class="k">return</span> <span class="n">BernsteinPolynomialModel</span><span class="p">(</span>
        <span class="n">x_model</span><span class="p">,</span>
        <span class="n">NonDecreasingCoefTransform</span><span class="p">()</span>
    <span class="p">)</span>
  <span class="k">else</span><span class="p">:</span>
    <span class="c1"># define a simple ReLU network - just add a linear layer
</span>    <span class="c1"># with one output on top of the ReLU network
</span>    <span class="n">layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>
</code></pre></div></div>

<p>Let’s verify that even without training, our ‘monotone’ model indeed produces non-decreasing functions of \(z\) for each \(\mathbf{x}\).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">2024</span><span class="p">)</span>  <span class="c1"># just to make this result reproducible
</span><span class="n">net</span> <span class="o">=</span> <span class="n">make_model</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">])</span>

<span class="n">plot_zs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>

<span class="n">func</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">30.</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">10</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Input = [30, 20, 10]'</span><span class="p">)</span>

<span class="n">func</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">10.</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">30</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Input = [10, 20, 30]'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/increasing_random_model.png" alt="increasing_random_model" /></p>

<p>Indeed, the model appears to produce non-decreasing functions of \(z\).</p>
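<p>This behavior comes from the shape-preserving property of Bernstein polynomials: if the coefficients are non-decreasing, so is the polynomial. Rather than eyeballing a plot, we can also check it numerically. The sketch below is self-contained and does not use <code>make_model</code>. The helper names <code>monotone_coefficients</code> and <code>bernstein_eval</code>, and the softplus-plus-cumulative-sum parameterization, are my own assumptions and may differ from how the model above produces its coefficients:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import torch
import torch.nn.functional as F


def monotone_coefficients(scores):
    # Hypothetical parameterization: the first coefficient is free, and every
    # following coefficient adds a positive (softplus) increment to the previous one.
    first = scores[..., :1]
    increments = F.softplus(scores[..., 1:])
    return torch.cat([first, first + torch.cumsum(increments, dim=-1)], dim=-1)


def bernstein_eval(coefs, zs):
    # Evaluate sum_k coefs[k] * C(n, k) * z^k * (1 - z)^(n - k) for every z in zs (z in [0, 1]).
    n = coefs.shape[-1] - 1
    k = torch.arange(n + 1, dtype=zs.dtype)
    binom = torch.tensor([math.comb(n, i) for i in range(n + 1)], dtype=zs.dtype)
    z = zs.unsqueeze(-1)  # shape (num_z, 1), broadcast against k
    basis = binom * z ** k * (1 - z) ** (n - k)
    return (basis * coefs).sum(dim=-1)


torch.manual_seed(2024)
scores = torch.randn(11, dtype=torch.float64)  # a degree-10 polynomial
coefs = monotone_coefficients(scores)
zs = torch.linspace(0, 1, 200, dtype=torch.float64)
values = bernstein_eval(coefs, zs)
min_step = torch.diff(values).min().item()
print(min_step)  # positive: consecutive values never decrease
</code></pre></div></div>

<p>The same check, the minimum difference between consecutive outputs over a grid of \(z\) values, can be applied to any model, including an unconstrained one, where nothing prevents it from coming out negative.</p>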

<p>Now to our last ingredient: model training. Here is a fairly standard PyTorch training loop with one small customization. Monotonic models accept the features as two arguments, \(\mathbf{x}\) and \(z\), whereas ‘regular’ models accept a single feature tensor, so for them we append \(z\) to \(\mathbf{x}\) as an extra column:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">tqdm.auto</span> <span class="kn">import</span> <span class="n">tqdm</span>

<span class="k">def</span> <span class="nf">train_epoch</span><span class="p">(</span><span class="n">data_iter</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">optim</span><span class="p">,</span> <span class="n">monotone</span><span class="p">):</span>
  <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">data_iter</span><span class="p">:</span>
    <span class="k">if</span> <span class="n">monotone</span><span class="p">:</span>
      <span class="n">pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
      <span class="n">pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)).</span><span class="n">squeeze</span><span class="p">()</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span>

    <span class="n">optim</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
    <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
    <span class="n">optim</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
</code></pre></div></div>

<p>And here is a similarly standard evaluation loop. It handles both kinds of models in the same way, and returns the average loss per sample:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">valid_epoch</span><span class="p">(</span><span class="n">data_iter</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">monotone</span><span class="p">):</span>
  <span class="n">epoch_loss</span> <span class="o">=</span> <span class="mf">0.</span>
  <span class="n">num_samples</span> <span class="o">=</span> <span class="mi">0</span>
  <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">data_iter</span><span class="p">:</span>
    <span class="k">if</span> <span class="n">monotone</span><span class="p">:</span>
      <span class="n">pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
      <span class="n">pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">z</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)).</span><span class="n">squeeze</span><span class="p">()</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span>
    <span class="n">epoch_loss</span> <span class="o">+=</span> <span class="n">loss</span> <span class="o">*</span> <span class="n">label</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">num_samples</span> <span class="o">+=</span> <span class="n">label</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">epoch_loss</span><span class="p">.</span><span class="n">cpu</span><span class="p">().</span><span class="n">item</span><span class="p">()</span> <span class="o">/</span> <span class="n">num_samples</span>
</code></pre></div></div>
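<p>The reason <code>valid_epoch</code> multiplies each batch loss by the batch size, instead of simply averaging the per-batch losses, is that the last batch may be smaller than <code>batch_size</code>. Weighting by the batch size recovers the mean loss over all samples exactly. A toy check with made-up data (10 samples split into batches of 4, 4, and 2):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

# Made-up predictions and labels; batches of 4, 4 and 2 mimic a smaller final batch.
torch.manual_seed(0)
pred = torch.randn(10)
label = torch.randn(10)
loss_fn = torch.nn.MSELoss()

weighted_sum, num_samples, batch_means = 0., 0, []
for start in range(0, 10, 4):
    p, y = pred[start:start + 4], label[start:start + 4]
    loss = loss_fn(p, y)              # mean over this batch only
    weighted_sum += loss * y.size(0)  # same accumulation as in valid_epoch
    num_samples += y.size(0)
    batch_means.append(loss)

weighted_mean = weighted_sum.item() / num_samples
full_mean = loss_fn(pred, label).item()              # MSE over all 10 samples at once
naive_mean = torch.stack(batch_means).mean().item()  # over-weights the small last batch
print(weighted_mean, full_mean, naive_mean)
</code></pre></div></div>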

<p>Now let’s integrate all ingredients into one function that creates a model and an optimizer, and runs several train+evaluation epochs using the mean squared error loss:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">train_model</span><span class="p">(</span><span class="n">train_iter</span><span class="p">,</span> <span class="n">valid_iter</span><span class="p">,</span> <span class="n">layer_dims</span><span class="p">,</span> <span class="n">monotone</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                <span class="n">optim_fn</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">SGD</span><span class="p">,</span> <span class="n">optim_params</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">num_epochs</span><span class="o">=</span><span class="mi">100</span><span class="p">):</span>
  <span class="k">if</span> <span class="n">optim_params</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
    <span class="n">optim_params</span> <span class="o">=</span> <span class="p">{}</span>

  <span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">2024</span><span class="p">)</span>
  <span class="n">model</span> <span class="o">=</span> <span class="n">make_model</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">,</span> <span class="n">monotone</span><span class="o">=</span><span class="n">monotone</span><span class="p">)</span>
  <span class="n">optim</span> <span class="o">=</span> <span class="n">optim_fn</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="o">**</span><span class="n">optim_params</span><span class="p">)</span>
  <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>

  <span class="k">if</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">():</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>

  <span class="k">with</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">))</span> <span class="k">as</span> <span class="n">epoch_range</span><span class="p">:</span>
    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">epoch_range</span><span class="p">:</span>
      <span class="n">train_epoch</span><span class="p">(</span><span class="n">train_iter</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">optim</span><span class="p">,</span> <span class="n">monotone</span><span class="p">)</span>
      <span class="n">epoch_loss</span> <span class="o">=</span> <span class="n">valid_epoch</span><span class="p">(</span><span class="n">valid_iter</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">monotone</span><span class="p">)</span>
      <span class="n">epoch_range</span><span class="p">.</span><span class="n">set_description</span><span class="p">(</span><span class="sa">f</span><span class="s">'Validation loss = </span><span class="si">{</span><span class="n">epoch_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">5</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">model</span><span class="p">,</span> <span class="n">epoch_loss</span>
</code></pre></div></div>

<p>Now let’s train! First, we create the training and validation datasets:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">batch_iter</span> <span class="kn">import</span> <span class="n">BatchIter</span>

<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">train_iter</span> <span class="o">=</span> <span class="n">BatchIter</span><span class="p">(</span><span class="o">*</span><span class="n">generate_dataset</span><span class="p">(</span><span class="mi">50000</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
<span class="n">valid_iter</span> <span class="o">=</span> <span class="n">BatchIter</span><span class="p">(</span><span class="o">*</span><span class="n">generate_dataset</span><span class="p">(</span><span class="mi">10000</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
</code></pre></div></div>
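<p>The <code>batch_iter</code> module is a helper whose source is not shown in this section. For readers without it, here is a hypothetical minimal stand-in that I am assuming behaves similarly: it takes several equally long tensors, such as the \((\mathbf{x}, z, \text{label})\) triple that <code>train_epoch</code> unpacks, and yields aligned mini-batches in a fresh random order on every pass. The real helper may differ, for example in shuffling or device placement:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch


class BatchIter:
    """Hypothetical stand-in for batch_iter.BatchIter: yields aligned mini-batches
    of several equally long tensors, reshuffled on every pass."""

    def __init__(self, *tensors, batch_size=256, shuffle=True):
        self.tensors = [torch.as_tensor(t) for t in tensors]
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        # number of batches, counting a possibly smaller last batch
        n = self.tensors[0].shape[0]
        return (n + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        n = self.tensors[0].shape[0]
        idx = torch.randperm(n) if self.shuffle else torch.arange(n)
        for start in range(0, n, self.batch_size):
            batch_idx = idx[start:start + self.batch_size]
            yield tuple(t[batch_idx] for t in self.tensors)


# Toy data with the (x, z, label) layout used by train_epoch and valid_epoch
x, z, y = torch.randn(1000, 3), torch.rand(1000), torch.randn(1000)
it = BatchIter(x, z, y, batch_size=256)
sizes = [xb.shape[0] for xb, zb, yb in it]
print(sizes)  # [256, 256, 256, 232]
</code></pre></div></div>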

<p>Now we train a monotonic model. I chose its architecture, the optimizer, and the optimizer’s parameters by hyperparameter tuning on the validation set. To keep this post straightforward, I just list the final hyperparameters I selected:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lr</span> <span class="o">=</span> <span class="mf">3e-3</span>
<span class="n">weight_decay</span> <span class="o">=</span> <span class="mf">1e-5</span>
<span class="n">degree</span> <span class="o">=</span> <span class="mi">50</span>
<span class="n">layer_dims</span> <span class="o">=</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span>
              <span class="mi">4</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="mi">3</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="mi">2</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="n">degree</span><span class="p">]</span>
<span class="n">model</span><span class="p">,</span> <span class="n">val_loss</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span>
    <span class="n">train_iter</span><span class="p">,</span> <span class="n">valid_iter</span><span class="p">,</span> <span class="n">layer_dims</span><span class="p">,</span>
    <span class="n">optim_fn</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">,</span>
    <span class="n">optim_params</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">))</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">cpu</span><span class="p">()</span>
</code></pre></div></div>

<p>I got a validation loss of <code class="language-plaintext highlighter-rouge">0.0127</code>. Now let’s plot some functions the model learned, and see how they compare to the “true” hairy function we designed. Here is code to produce the function for \(\mathbf{x} = (1, 0.5, -0.5)\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">func</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">features</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Model function'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">hairy_increasing_func</span><span class="p">(</span><span class="n">features</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">plot_zs</span><span class="p">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">label</span><span class="o">=</span><span class="s">'True function'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/hairy_bernstein_fit_1.png" alt="hairy_bernstein_fit_1" /></p>

<p>Seems pretty close. Let’s try another one with \(\mathbf{x} = (-1.5, 0.8, 0.1)\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">1.5</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">func</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">features</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">func</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Model function'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">hairy_increasing_func</span><span class="p">(</span><span class="n">features</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">plot_zs</span><span class="p">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">label</span><span class="o">=</span><span class="s">'True function'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/hairy_bernstein_fit_2.png" alt="hairy_bernstein_fit_2" /></p>

<p>A bit farther away, but not very bad.</p>

<p>Now let’s train a regular ReLU network on the same dataset and see which functions it learns. Its architecture is similar to the \(\mathbf{x}\) network of the monotonic model above, but its input dimension is four instead of three, because \(z\) is no longer handled separately from the other features. Here is the code to train the network:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lr</span> <span class="o">=</span> <span class="mf">3e-3</span>
<span class="n">weight_decay</span> <span class="o">=</span> <span class="mf">1e-5</span>
<span class="n">degree</span> <span class="o">=</span> <span class="mi">50</span> <span class="c1"># there is no "degree" - it's here just to preserve model architecture.
</span><span class="n">layer_dims</span> <span class="o">=</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span>
              <span class="mi">4</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="mi">3</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="mi">2</span> <span class="o">*</span> <span class="n">degree</span><span class="p">,</span>
              <span class="n">degree</span><span class="p">]</span>
<span class="n">model</span><span class="p">,</span> <span class="n">val_loss</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span>
    <span class="n">train_iter</span><span class="p">,</span> <span class="n">valid_iter</span><span class="p">,</span> <span class="n">layer_dims</span><span class="p">,</span> <span class="n">monotone</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
    <span class="n">optim_fn</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">,</span>
    <span class="n">optim_params</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">))</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">cpu</span><span class="p">()</span>
</code></pre></div></div>

<p>I got a validation loss of \(0.01404\): slightly worse, but not by much. Let’s see which functions we get for the same two vectors \(\mathbf{x}\) we tried before. Here is the code for \(\mathbf{x} = (1, 0.5, -0.5)\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
    <span class="n">plot_zs</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">model</span><span class="p">(</span><span class="n">features</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Model function'</span><span class="p">)</span>
<span class="c1"># the true function takes only x, i.e. the first three columns (the fourth column is z)
</span><span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">hairy_increasing_func</span><span class="p">(</span><span class="n">features</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">3</span><span class="p">].</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">plot_zs</span><span class="p">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">label</span><span class="o">=</span><span class="s">'True function'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/hairy_relu_fit_1.png" alt="hairy_relu_fit_1" /></p>

<p>The model function appears monotonic. Is this a coincidence? Well, let’s try our second vector \(\mathbf{x} = (-1.5, 0.8, 0.1)\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">1.5</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]).</span><span class="n">repeat</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
    <span class="n">plot_zs</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">model</span><span class="p">(</span><span class="n">features</span><span class="p">).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">'Model function'</span><span class="p">)</span>
<span class="c1"># the true function takes only x, i.e. the first three columns (the fourth column is z)
</span><span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">plot_zs</span><span class="p">,</span> <span class="n">hairy_increasing_func</span><span class="p">(</span><span class="n">features</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">3</span><span class="p">].</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">plot_zs</span><span class="p">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">label</span><span class="o">=</span><span class="s">'True function'</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://alexshtf.github.io/assets/hairy_relu_fit_2.png" alt="hairy_relu_fit_2" /></p>

<p>This one isn’t! If we think about it, there is a good reason. Our synthetic dataset was generated by sampling standard normal variables, so vectors whose features are close to zero are more common than vectors whose features are farther away. The vector \(\mathbf{x} = (1, 0.5, -0.5)\) is closer to the origin than \(\mathbf{x} = (-1.5, 0.8, 0.1)\) (its squared norm is \(1.5\), versus \(2.9\)), so there was more training data similar to the former vector than to the latter. Consequently, the model could better learn the functions in the neighbourhood of the former vector. However, when a model is monotone <em>by design</em>, we don’t rely on having enough data for the model to discover monotonic behavior: it’s built into the model.</p>
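<p>To put a rough number on “more common”: assuming, as stated above, that the three features of \(\mathbf{x}\) are independent standard normal variables, the density is proportional to \(\exp(-\|\mathbf{x}\|^2/2)\), so the ratio of densities at the two vectors is \(\exp((2.9 - 1.5)/2) = e^{0.7}\). A quick check:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


# Assumption: x ~ N(0, I) in three dimensions, as in the synthetic dataset described above.
def std_normal_log_density(x):
    d = len(x)
    return -0.5 * sum(v * v for v in x) - 0.5 * d * math.log(2 * math.pi)


near = (1.0, 0.5, -0.5)
far = (-1.5, 0.8, 0.1)
ratio = math.exp(std_normal_log_density(near) - std_normal_log_density(far))
print(f"{ratio:.3f}")  # 2.014
</code></pre></div></div>

<p>This compares the density only at the two points themselves, so it is a rough illustration of why the neighbourhood of the second vector is less densely covered by training data, not a full explanation of the learned shapes.</p>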

<h1 id="summary-and-discussion">Summary and discussion</h1>

<p>In this post we saw an interesting combination of neural networks with Bernstein polynomials that allows us to learn functions satisfying shape constraints. This is useful when the shape constraint is actually a <em>constraint</em>, i.e. required for the predictions of the model to be correct from a mathematical or business perspective. Moreover, it’s a form of regularization, since that’s what regularization often is: injecting <em>prior knowledge</em> about the hypothesis class into the fitting procedure.</p>

<p>The idea of constraining the coefficients of a function in a given basis to constrain its shape works not only for Bernstein polynomials, but also for the <a href="https://en.wikipedia.org/wiki/B-spline">B-Spline basis</a><sup id="fnref:8" role="doc-noteref"><a href="#fn:8" class="footnote" rel="footnote">6</a></sup>, and probably for a variety of other ‘shape-preserving’ bases that I have never heard of. So you’re welcome to try this idea with those bases as well, if you believe they suit your needs.</p>
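<p>For instance, a B-spline whose coefficients are non-decreasing is itself non-decreasing, because its derivative is a lower-degree spline whose coefficients are non-negative multiples of the coefficient differences. Here is a small sketch using SciPy’s <code>BSpline</code> class. The library, the knots, and the random coefficients are my own choices for illustration and are not necessarily the construction used in the footnoted reference:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from scipy.interpolate import BSpline

# Cubic B-spline on [0, 1] with clamped (repeated) boundary knots.
k = 3
interior = [0.25, 0.5, 0.75]
t = np.r_[[0.0] * (k + 1), interior, [1.0] * (k + 1)]
n_coefs = len(t) - k - 1  # 7 coefficients for these 11 knots

# Non-decreasing coefficients: a free offset plus a cumulative sum of non-negative steps.
rng = np.random.default_rng(0)
c = rng.normal() + np.cumsum(np.abs(rng.normal(size=n_coefs)))
spline = BSpline(t, c, k)

zs = np.linspace(0, 1, 500)
values = spline(zs)
print(np.diff(values).min())  # non-negative: the spline never decreases on the grid
</code></pre></div></div>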

<p>A natural variation could be designing a polynomial that is monotonic, non-negative, convex, or concave over the entire real line \((-\infty, \infty)\). An interesting theorem dating back to Hilbert’s 1888 paper<sup id="fnref:6" role="doc-noteref"><a href="#fn:6" class="footnote" rel="footnote">7</a></sup> states that a polynomial \(p(z)\) of degree \(2d\) is non-negative over the entire real line if and only if it is a sum of squares of polynomials. Equivalently, there exists a positive-semidefinite matrix \(\mathbf{P} \in \mathbb{R}^{(d+1) \times (d+1)}\) such that the polynomial can be written as</p>

\[p(z;\mathbf{P}) = \begin{pmatrix}1 &amp; z &amp; \dots &amp; z^d\end{pmatrix} \mathbf{P} \begin{pmatrix}1 \\ z \\ \vdots \\ z^d \end{pmatrix}.\]

<p>Any positive semidefinite matrix \(\mathbf{P}\) can be decomposed as \(\mathbf{P} = \mathbf{V} \mathbf{V}^T\). So just like we predicted the Bernstein coefficient vector \(\mathbf{a}\) based on the features \(\mathbf{x}\), we could alternatively build a model that learns to predict \(\mathbf{V}\).</p>
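<p>To make this representation concrete, here is a small NumPy sketch of my own, with an arbitrary degree and a random \(\mathbf{V}\) standing in for a model’s prediction. Since the monomial vector \(\mathbf{m}(z) = (1, z, \dots, z^d)^T\) has \(d+1\) entries, \(\mathbf{P}\) and \(\mathbf{V}\) are \((d+1) \times (d+1)\) here. The sketch builds \(\mathbf{P} = \mathbf{V}\mathbf{V}^T\), reads off the coefficients of \(p\) from the anti-diagonals of \(\mathbf{P}\), and checks on a grid that the result matches the sum-of-squares form \(\|\mathbf{V}^T \mathbf{m}(z)\|^2\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import polynomial as npoly

d = 3                                # p has degree 2d = 6
rng = np.random.default_rng(0)
V = rng.normal(size=(d + 1, d + 1))  # in a model, V would be predicted from the features x
P = V @ V.T                          # positive semidefinite by construction

# p(z) = m(z)^T P m(z) with m(z) = (1, z, ..., z^d): the coefficient of z^k is the
# sum of the entries P[i, j] with i + j = k (the k-th anti-diagonal of P).
coefs = np.zeros(2 * d + 1)
for i in range(d + 1):
    for j in range(d + 1):
        coefs[i + j] += P[i, j]

zs = np.linspace(-5, 5, 1001)
values = npoly.polyval(zs, coefs)

# The same polynomial written as a sum of squares: p(z) = ||V^T m(z)||^2.
monomials = np.vander(zs, d + 1, increasing=True)  # row r is m(zs[r])
sos_values = ((monomials @ V) ** 2).sum(axis=1)
print(values.min(), np.abs(values - sos_values).max())
</code></pre></div></div>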

<p>Since a polynomial is non-decreasing if and only if its derivative is non-negative, we can obtain a non-decreasing polynomial by integrating a non-negative one. Similarly, convexity can be represented by integrating a non-negative polynomial twice, since a polynomial is convex if and only if its second derivative is non-negative. In both cases, the result is still linear in \(\mathbf{P}\): integrating from \(0\) replaces the monomial \(z^{i+j}\) multiplying \(P_{ij}\) by \(z^{i+j+1}/(i+j+1)\), or by \(z^{i+j+2}/((i+j+1)(i+j+2))\) for double integration. Similar “sum of squares” techniques can be used to construct polynomials that are non-negative over an interval, and hence, after integration, monotone over that interval. See Blekherman et al.<sup id="fnref:7" role="doc-noteref"><a href="#fn:7" class="footnote" rel="footnote">8</a></sup>, Theorem 3.72.</p>
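<p>A minimal NumPy sketch of the integration idea, again with my own arbitrary degree and random \(\mathbf{V}\): we build a non-negative polynomial \(p\) from \(\mathbf{P} = \mathbf{V}\mathbf{V}^T\), then use <code>Polynomial.integ</code> to obtain one polynomial whose derivative is \(p\), hence non-decreasing on the whole real line, and another whose second derivative is \(p\), hence convex. The grid checks only sample the real line, so they illustrate the property rather than prove it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from numpy.polynomial import Polynomial

d = 2
rng = np.random.default_rng(1)
V = rng.normal(size=(d + 1, d + 1))
P = V @ V.T  # positive semidefinite

# p(z) = m(z)^T P m(z): the coefficient of z^k collects the entries P[i, j] with i + j = k.
coefs = np.zeros(2 * d + 1)
for i in range(d + 1):
    for j in range(d + 1):
        coefs[i + j] += P[i, j]
p = Polynomial(coefs)  # non-negative on the whole real line

increasing = p.integ()  # its derivative is p, so it is non-decreasing everywhere
convex = p.integ(2)     # its second derivative is p, so it is convex everywhere

zs = np.linspace(-3, 3, 601)
print(np.diff(increasing(zs)).min())   # non-negative (up to round-off)
print(np.diff(convex(zs), n=2).min())  # non-negative (up to round-off)
</code></pre></div></div>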

<p>Now let’s get back to the realm of Bernstein polynomials. What happens if we want a polynomial that is <em>both</em> convex and increasing? Or both concave and increasing? The latter is useful, for example, for modeling a utility function with diminishing returns. In this case we need to impose <em>two</em> constraints on the coefficient vector of the polynomial: one for monotonicity, and another for concavity. This is straightforward with convex optimization solvers that support constraints out of the box, but harder to achieve when we train a neural network with PyTorch that must produce a coefficient vector satisfying several constraints at once. This is exactly what we shall explore in the next post!</p>
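<p>To make the two constraints concrete, here is a minimal NumPy sketch for a Bernstein polynomial \(\sum_{i=0}^{n} a_i \binom{n}{i} z^i (1-z)^{n-i}\) on \([0, 1]\). Its derivative is \(n \sum_{i=0}^{n-1} (a_{i+1} - a_i) \binom{n-1}{i} z^i (1-z)^{n-1-i}\), so non-decreasing coefficients are sufficient for an increasing polynomial; applying the same argument twice, non-positive second differences \(a_{i+2} - 2a_{i+1} + a_i \le 0\) are sufficient for concavity. The coefficient values and the helper <code>bernstein_eval</code> are illustrative assumptions.</p>

```python
import numpy as np
from math import comb


def bernstein_eval(a: np.ndarray, z: np.ndarray) -> np.ndarray:
    """Evaluate sum_i a[i] * C(n, i) * z^i * (1 - z)^(n - i), n = len(a) - 1, for z in [0, 1]."""
    n = len(a) - 1
    z = np.asarray(z, dtype=float)
    return sum(a[i] * comb(n, i) * z**i * (1 - z) ** (n - i) for i in range(n + 1))


# Hypothetical coefficients satisfying both constraints at once
a = np.array([0.0, 1.0, 1.8, 2.4, 2.8, 3.0])
increasing_ok = np.all(np.diff(a) >= 0)      # first differences >= 0  -> increasing
concave_ok = np.all(np.diff(a, n=2) <= 0)    # second differences <= 0 -> concave
print(increasing_ok, concave_ok)             # True True

# The evaluated polynomial is increasing and concave on [0, 1]
z = np.linspace(0.0, 1.0, 201)
vals = bernstein_eval(a, z)
```

<p>Both conditions are linear inequalities on \(\mathbf{a}\), which is why a constrained solver handles them directly, while a network has to produce outputs that satisfy them jointly.</p>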

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>You, S., Ding, D., Canini, K., Pfeifer, J., &amp; Gupta, M. (2017). Deep lattice networks and partial monotonic functions. <em>Advances in neural information processing systems</em>, <em>30</em>. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>Hastie, T., &amp; Tibshirani, R. (1993). Varying-coefficient models. <em>Journal of the Royal Statistical Society Series B: Statistical Methodology</em>, <em>55</em>(4), 757-779. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:3" role="doc-endnote">
      <p>Ghosal, R., Ghosh, S., Urbanek, J., Schrack, J. A., &amp; Zipunnikov, V. (2023). Shape-constrained estimation in functional regression with Bernstein polynomials. <em>Computational Statistics &amp; Data Analysis</em>, <em>178</em>, 107614. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:4" role="doc-endnote">
      <p>Hoover, D. R., Rice, J. A., Wu, C. O., &amp; Yang, L. P. (1998). Nonparametric smoothing estimates of time-varying coefficient models with longitudinal data. <em>Biometrika</em>, <em>85</em>(4), 809-822. <a href="#fnref:4" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:5" role="doc-endnote">
      <p>Huang, J. Z., Wu, C. O., &amp; Zhou, L. (2004). Polynomial spline estimation and inference for varying coefficient models with longitudinal data. <em>Statistica Sinica</em>, 763-788. <a href="#fnref:5" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:8" role="doc-endnote">
      <p>de Boor, C. (1993). <em>A practical guide to splines</em>. Springer. <a href="#fnref:8" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:6" role="doc-endnote">
      <p>Hilbert, D. (1888). Über die Darstellung definiter Formen als Summe von Formenquadraten. <em>Mathematische Annalen</em>, <em>32</em>(3), 342-350. <a href="#fnref:6" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:7" role="doc-endnote">
      <p>Blekherman, G., Parrilo, P. A., &amp; Thomas, R. R. (2012). <em>Semidefinite Optimization and Convex Algebraic Geometry</em>. SIAM. <a href="#fnref:7" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Alex Shtoff</name><email>alex.shtf@gmail.com</email></author><category term="pytorch" /><category term="machine-learning" /><category term="monotonic-regression" /><category term="bernstein" /><category term="polynomial-regression" /><summary type="html"><![CDATA[Build shape-restricted models in PyTorch by predicting Bernstein polynomial coefficients: monotone and bounded functions of a chosen feature, with flexible dependence on the rest.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="https://alexshtf.github.io/assets/increasing_function_model.png" /><media:content medium="image" url="https://alexshtf.github.io/assets/increasing_function_model.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry></feed>