{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# UoI-NMF for robust parts-based decomposition of noisy data\n\nThis example demonstrates parts-based decomposition with UoI-NMF on the\nswimmer dataset after corrupting it with noise. The swimmer dataset is a\ncanonical example of separable data.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Swimmer dataset\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom sklearn.manifold import TSNE\nfrom sklearn.preprocessing import minmax_scale\n\nfrom pyuoi.datasets import load_swimmer\nfrom pyuoi.decomposition import UoI_NMF\n\n# Small square figures suit the 32x32 swimmer images plotted below.\nmatplotlib.rcParams['figure.figsize'] = [4, 4]\n\n# Seed NumPy's global RNG once: the sample selection and the noise below\n# draw from it, and neither UoI_NMF nor TSNE is given a random_state.\nnp.random.seed(10)\n\n# One flattened swimmer image per row, each row rescaled to [0, 1].\nswimmers = minmax_scale(load_swimmer(), axis=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Original Swimmer samples\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Draw one random swimmer from each consecutive run of 16 images; the\n# corrupted and recovered grids reuse these `indices`.\nindices = np.arange(0, 256, 16) + np.random.randint(16, size=16)\n\nfig, ax = plt.subplots(4, 4, subplot_kw={'xticks': [], 'yticks': []})\nfor panel, idx in zip(ax.flat, indices):\n    # Each row of `swimmers` is a flattened 32x32 image.\n    panel.imshow(swimmers[idx].reshape(32, 32).T, aspect='auto', cmap='gray')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Swimmer samples corrupted with absolute Gaussian noise\n\nCorrupt the images with absolute Gaussian noise, i.e. the absolute value of\nzero-mean Gaussian noise with ``std = 0.25``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Corrupt every image with the absolute value of zero-mean Gaussian noise;\n# `reps` independently corrupted copies of the data are stacked row-wise.\nreps = 1\nnoise_std = 0.25\ncorrupted = np.vstack([\n    swimmers + np.abs(np.random.normal(scale=noise_std, size=swimmers.shape))\n    for _ in range(reps)\n])\n\nfig, ax = plt.subplots(4, 4, subplot_kw={'xticks': [], 'yticks': []})\nfor panel, idx in zip(ax.flat, indices):\n    panel.imshow(corrupted[idx].reshape(32, 32).T, aspect='auto', cmap='gray')"
      ]
    },
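    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Run UoI NMF on corrupted Swimmer data\n\nTwenty bootstraps should be enough.\n``min_pts``, which is passed to UoI-NMF as ``db_min_samples`` (the\nminimum-samples setting of its DBSCAN clustering step), is set to half\nthe number of bootstraps.\n\n"
      ]
    },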
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# 20 bootstrap resamples, a single NMF rank of 16, and up to 800 NMF\n# iterations per fit. `min_pts` is passed as `db_min_samples`, presumably\n# forwarded to DBSCAN's `min_samples`; integer division keeps it an int,\n# since recent scikit-learn versions reject a float such as 10.0 there.\nnboot = 20\nmin_pts = nboot // 2\nranks = [16]\n\nuoi_nmf = UoI_NMF(n_boots=nboot, ranks=ranks, db_min_samples=min_pts,\n                  nmf_max_iter=800)\n\n# Per-sample weights on the learned bases, and the low-rank reconstruction\n# of the corrupted data they define.\ntransformed = uoi_nmf.fit_transform(corrupted)\nrecovered = transformed @ uoi_nmf.components_"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NMF Swimmer bases\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Order the learned bases by total intensity for a reproducible layout.\norder = np.argsort(np.sum(uoi_nmf.components_, axis=1))\n\n# Size the grid from the bases actually returned: their count may differ\n# from the requested rank, and a fixed 4x4 grid raises IndexError when\n# there are more than 16 of them.\nn_cols = 4\nn_rows = max(1, int(np.ceil(len(order) / n_cols)))\nfig, ax = plt.subplots(n_rows, n_cols, squeeze=False,\n                       subplot_kw={'xticks': [], 'yticks': []})\nax = ax.flatten()\nfor panel, k in zip(ax, order):\n    panel.imshow(uoi_nmf.components_[k].reshape(32, 32).T,\n                 aspect='auto', cmap='gray')\n# Hide the unused panels of a partially filled last row.\nfor panel in ax[len(order):]:\n    panel.set_visible(False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Recovered Swimmers\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Reconstructions of the swimmers picked by `indices`, for comparison with\n# the original and corrupted grids above.\nfig, ax = plt.subplots(4, 4, subplot_kw={'xticks': [], 'yticks': []})\nfor panel, idx in zip(ax.flat, indices):\n    panel.imshow(recovered[idx].reshape(32, 32).T, aspect='auto', cmap='gray')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot them all together so we can see how well we recovered\nthe original swimmer data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# One row per version of the data, one column per sampled swimmer. This\n# draws a fresh sample and rebinds `indices`, so the columns need not match\n# the swimmers shown in the grids above.\nindices = np.random.randint(16, size=16) + np.arange(0, 256, 16)\n\n# (row label, images, x position of the label in axes coordinates)\nrows = [('Original', swimmers, -1.0),\n        ('Corrupted', corrupted, -1.1),\n        ('Recovered', recovered, -1.1)]\n\nfig, ax = plt.subplots(len(rows), len(indices), figsize=(27, 5),\n                       squeeze=False,\n                       subplot_kw={'xticks': [], 'yticks': []})\nfor row_axes, (label, images, label_x) in zip(ax, rows):\n    # Label each row once, to the left of its first panel.\n    row_axes[0].set_ylabel(label, rotation=0, fontsize=25, labelpad=40)\n    row_axes[0].yaxis.set_label_coords(label_x, 0.5)\n    for panel, idx in zip(row_axes, indices):\n        panel.imshow(images[idx].reshape(32, 32).T,\n                     aspect='auto', cmap='gray')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what DBSCAN is doing, embed the bootstrapped bases samples in two\ndimensions with t-SNE and color each one by its DBSCAN cluster label.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Embed the bootstrapped bases in 2-D with t-SNE; each point is drawn as a\n# hollow circle whose edge color encodes its DBSCAN cluster label.\nembedding = TSNE(n_components=2).fit_transform(uoi_nmf.bases_samples_)\nlabels = np.asarray(uoi_nmf.bases_samples_labels_)\n\n# Map labels to colors explicitly (same min/max scaling scatter's `c=` uses)\n# and pass them as edge colors. Calling set_facecolor('none') on a\n# color-mapped scatter also blanks its edges on recent matplotlib, where the\n# default edge color is 'face', leaving the markers invisible.\nnorm = matplotlib.colors.Normalize(vmin=labels.min(), vmax=labels.max())\nedge_colors = plt.get_cmap('nipy_spectral')(norm(labels))\n\nplt.figure()\nplt.scatter(embedding[:, 0], embedding[:, 1], s=80,\n            facecolors='none', edgecolors=edge_colors)\nplt.title('t-SNE of bootstrapped bases (color = DBSCAN label)')\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}