{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Granley et al. (2021): Effects of Biphasic Pulse Parameters with the BiphasicAxonMapModel\n\nThis example shows how to use the\n:py:class:`~pulse2percept.models.BiphasicAxonMapModel` to model the effects of \nbiphasic pulse train parameters on phosphene appearance in an epiretinal\nimplant such as :py:class:`~pulse2percept.implants.ArgusII`. \n\nBiphasic pulse trains are a commonly used type of stimulus in visual prostheses. \nThis model extends the :py:class:`~pulse2percept.models.AxonMapModel` to reflect\nthe effects of amplitude, frequency, and pulse duration on threshold,\nphosphene size, brightness, and streak length, as reported in previous\npsychophysical and electrophysiological studies.\n\nThe :py:class:`~pulse2percept.models.BiphasicAxonMapModel` shares the same underlying \nassumptions as the axon map model. Namely, an axon's sensitivity to electrical stimulation\nis assumed to decay exponentially with:\n\n*  distance along the axon from the soma ($d_s$), with spatial decay\n   constant $\\lambda$,\n*  distance from the stimulated electrode ($d_e$), with spatial decay \n   constant $\\rho$.\n\nIn the biphasic model, the radial decay rate $\\rho$ is scaled by $F_{size}$,\nthe axonal decay rate $\\lambda$ is scaled by $F_{streak}$, and the brightness \ncontribution from each electrode is scaled by $F_{bright}$. These three functions are called\neffect models. The brightness $I$ of a pixel located at polar \ncoordinates $(r, \\theta)$ is then given by:\n\n\\begin{align}I = \\max_{axon}\\sum_{elecs}F_\\mathrm{bright} \\exp\\left(-\\frac{d_{e}^2}{2\\rho^2 F_\\mathrm{size}} - \\frac{d_{s}^2}{2\\lambda^2 F_\\mathrm{streak}}\\right).\\end{align}\n\n\n## Basic Model Usage\nThe biphasic axon map model can be instantiated and run similarly to other models,\nwith the exception that all stimuli must be :py:class:`~pulse2percept.stimuli.BiphasicPulseTrain` objects.\n"
      ]
    },
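    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the equation above concrete, the cell below evaluates it for a single pixel with NumPy. This is a minimal sketch, not how pulse2percept computes percepts internally: ``pixel_brightness`` is a hypothetical helper, and it assumes you already know, for each segment of the axon passing through the pixel, the segment's distance to every electrode ($d_e$) and its distance along the axon from the soma ($d_s$), as well as the effect model outputs for every electrode. The distances are made-up toy values.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef pixel_brightness(d_e, d_s, rho, axlambda, f_bright, f_size, f_streak):\n    # Hypothetical helper, not part of pulse2percept.\n    # d_e: (n_segments, n_elecs) distance from each axon segment to each electrode\n    # d_s: (n_segments,) distance along the axon from the soma to each segment\n    # f_bright, f_size, f_streak: (n_elecs,) effect model outputs per electrode\n    d_s = d_s[:, np.newaxis]  # broadcast against the electrode axis\n    gauss = np.exp(-d_e ** 2 / (2 * rho ** 2 * f_size)\n                   - d_s ** 2 / (2 * axlambda ** 2 * f_streak))\n    # Sum over electrodes at each segment, then take the max along the axon\n    return np.max(np.sum(f_bright * gauss, axis=1))\n\n\n# Toy example: one axon with 3 segments and 2 electrodes (distances in microns)\nd_e = np.array([[0.0, 400.0], [100.0, 300.0], [200.0, 200.0]])\nd_s = np.array([0.0, 100.0, 200.0])\nprint(pixel_brightness(d_e, d_s, rho=200, axlambda=800,\n                       f_bright=np.array([1.0, 2.0]),\n                       f_size=np.ones(2), f_streak=np.ones(2)))"
      ]
    },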
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\nimport numpy as np\nfrom pulse2percept.implants import ArgusII\nfrom pulse2percept.models import BiphasicAxonMapModel\nfrom pulse2percept.stimuli import BiphasicPulseTrain\nmodel = BiphasicAxonMapModel(rho=200, axlambda=800)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Parameters you don't specify will take on default values. You can inspect\nall current model parameters as follows:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The most important parameters are ``rho`` and ``axlambda``, which control the \nradial and axonal current spread, respectively. The parameters ``a0``-``a9`` are \ncoefficients for the brightness, size, and streak effect models, which are discussed \nlater in this example. The biphasic axon map model supports both the default \nCython engine and a faster, GPU-enabled JAX engine.\n\nThe remaining parameters are shared with \n:py:class:`~pulse2percept.models.AxonMapModel`. For full details on these \nparameters, see the Axon Map tutorial.\n\nNext, build the model to perform expensive, one-time calculations,\nand specify a visual prosthesis from the\n:py:mod:`~pulse2percept.implants` module. Models with an axon map are well \nsuited for epiretinal implants such as Argus II.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model.build()\nimplant = ArgusII()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        ".. important ::\n\n    You need to build a model only once. After that, you can apply any number\n    of stimuli -- or even apply the model to different implants -- without\n    having to rebuild (which takes time).\n\n    However, if you change model parameters\n    (e.g., by directly setting ``model.a5 = 2``), you will have to\n    call ``model.build()`` again for your changes to take effect.\n\n\nYou can visualize the location of the implant and the axon map:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model.plot()\nimplant.plot()\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As mentioned above, the biphasic axon map model only accepts \n:py:class:`~pulse2percept.stimuli.BiphasicPulseTrain`\nstimuli with no :py:attr:`~pulse2percept.stimuli.BiphasicPulseTrain.delay_dur`. \nThe amplitude given to a ``BiphasicPulseTrain``\nis interpreted as a multiple of the threshold amplitude (i.e., an amplitude of 1 means \n1xTh).\n\nYou can assign pulse trains to electrodes with a dictionary that maps electrode names to stimuli.\nThe following creates a train with 20 Hz frequency, 1xTh amplitude, and 0.45 ms\npulse (phase) duration on electrode A4.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# 20 Hz, 1xTh, 0.45 ms pulse duration on electrode A4\nimplant.stim = {'A4': BiphasicPulseTrain(20, 1, 0.45)}\nimplant.stim.plot()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, you can predict the percept resulting from stimulation:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "percept = model.predict_percept(implant)\nax = percept.plot()\nax.set_title('Predicted percept')\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Increasing the frequency makes phosphenes brighter:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)\nimplant.stim = {'A4': BiphasicPulseTrain(50, 1, 0.45)}\nnew_percept = model.predict_percept(implant)\nnew_percept.plot(ax=axes[1])\n# Share the color scale so that the brightness difference is visible\npercept.plot(ax=axes[0], vmax=new_percept.max())\naxes[0].set_title(\"20 Hz\")\naxes[1].set_title(\"50 Hz\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that without setting ``vmax``, matplotlib automatically rescales each image to\nits own maximum brightness, so the difference would not be visible.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Increasing amplitude increases both size and brightness:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)\nimplant.stim = {'A4' : BiphasicPulseTrain(20, 3, 0.45)}\nnew_percept = model.predict_percept(implant)\nnew_percept.plot(ax=axes[1])\npercept.plot(ax=axes[0], vmax=new_percept.max())\naxes[0].set_title(\"1xTh\")\naxes[1].set_title(\"3xTh\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Increasing pulse duration decreases the threshold, thus indirectly increasing \nsize and brightness (the amplitude relative to threshold goes up):\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)\nimplant.stim = {'A4' : BiphasicPulseTrain(20, 1, 4)}\nnew_percept = model.predict_percept(implant)\nnew_percept.plot(ax=axes[1])\npercept.plot(ax=axes[0], vmax=new_percept.max())\naxes[0].set_title(\"0.45ms\")\naxes[1].set_title(\"4ms\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you account for the change in threshold by decreasing the amplitude, then \nthe only effect of increasing pulse duration is a decrease in streak length:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)\nimplant.stim = {'A4' : BiphasicPulseTrain(20, 0.023835, 20)}\nnew_percept = model.predict_percept(implant)\nnew_percept.plot(ax=axes[1])\npercept.plot(ax=axes[0], vmax=new_percept.max())\naxes[0].set_title(\"0.45ms\")\naxes[1].set_title(\"20ms, 0.02xTh\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This illustrates another important point: the amplitude used by the biphasic\nmodel is relative to the threshold current at 0.45 ms pulse duration. Since longer \npulse durations have been shown to reduce the threshold amplitude, the \n0.02xTh amplitude used in the previous plot is still able to produce a phosphene.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Changing Effect Models\nAll of the effects plotted above (e.g., size increasing with amplitude)\nare controlled by the effect models $F_{bright}$, $F_{size}$, and\n$F_{streak}$, which are stored in the model attributes \n``bright_model``, ``size_model``, and ``streak_model``.\n\nThese default to :py:class:`~pulse2percept.models.granley2021.DefaultBrightModel`,\n:py:class:`~pulse2percept.models.granley2021.DefaultSizeModel`, and \n:py:class:`~pulse2percept.models.granley2021.DefaultStreakModel`, respectively, which\nimplement the simple scaling functions described in [Granley2021]_.\n\nThe coefficients ``a0``-``a9`` parametrize these effect models. While the default values\nare likely to work in most cases, they can be customized for individual patients. \nNote that you only have to change the value on the ``BiphasicAxonMapModel``; \nit is automatically passed down to the effect models.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model.a5 = 0\nprint(model.size_model.a5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For example, ``a0`` and ``a1`` control how the threshold changes with pulse duration: \nthe threshold scales as $(a_0 \\cdot pdur + a_1)^{-1}$, so the amplitude relative to threshold becomes $amp \\cdot (a_0 \\cdot pdur + a_1)$. Thus, pulse duration threshold \nscaling can be disabled by setting ``a0`` to 0 and ``a1`` to 1. If we increase \npulse duration as we did previously, we now see that only the streak length decreases, \nand we no longer have to change the amplitude to account for the change in threshold:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = BiphasicAxonMapModel(rho=200, axlambda=800)\nmodel.a0 = 0\nmodel.a1 = 1\nmodel.build()\nfig, axes = plt.subplots(1, 2, sharex=True, sharey=True)\nimplant.stim = {'A4' : BiphasicPulseTrain(20, 1, 0.45)}\npercept = model.predict_percept(implant)\nimplant.stim = {'A4' : BiphasicPulseTrain(20, 1, 20)}\nnew_percept = model.predict_percept(implant)\nnew_percept.plot(ax=axes[1])\npercept.plot(ax=axes[0], vmax=new_percept.max())\naxes[0].set_title(\"0.45ms\")\naxes[1].set_title(\"20ms\")\nplt.show()"
      ]
    },
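    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how an amplitude like the ``0.023835`` used earlier can offset a longer pulse, the cell below writes the threshold scaling out in plain Python. This is a sketch under stated assumptions: ``effective_amp`` and ``matching_amp`` are hypothetical helpers, the threshold is assumed to scale as $(a_0 \\cdot pdur + a_1)^{-1}$ so that the amplitude relative to threshold is $amp \\cdot (a_0 \\cdot pdur + a_1)$, and the coefficient values are placeholders rather than the model's defaults (``print(model)`` shows the actual values).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Hypothetical helpers; the coefficients below are placeholders, not model defaults\n\n\ndef effective_amp(amp, pdur, a0, a1):\n    # Longer pulses lower the threshold, raising the amplitude relative to threshold\n    return amp * (a0 * pdur + a1)\n\n\ndef matching_amp(amp_ref, pdur_ref, pdur_new, a0, a1):\n    # Amplitude at pdur_new giving the same effective amplitude as amp_ref at pdur_ref\n    return amp_ref * (a0 * pdur_ref + a1) / (a0 * pdur_new + a1)\n\n\na0, a1 = 2.0, 0.1  # placeholder coefficients\namp_20ms = matching_amp(1, 0.45, 20, a0, a1)\nprint(amp_20ms, effective_amp(amp_20ms, 20, a0, a1))\n\n# With a0 = 0 and a1 = 1, pulse duration no longer changes the effective amplitude\nprint(effective_amp(1, 20, a0=0, a1=1))"
      ]
    },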
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Similarly, ``a2``-``a4`` control brightness scaling, ``a5``-``a6`` control size scaling, and\n``a7``-``a9`` control streak length scaling. For more details on these parameters,\nsee the effect models' documentation or [Granley2021]_.\n\n## Advanced Usage\n\n### Custom Effect Models\nIn most cases, the provided default effect models\nshould be enough. However, the effect models are completely modular and \ncan be replaced by any Python callable that takes frequency, amplitude, \nand pulse duration as arguments. For example, we can change the model so that it no longer scales size:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = BiphasicAxonMapModel(rho=200, axlambda=800)\n\n\ndef size_modulation(freq, amp, pdur):\n    # Ignore the stimulus: phosphene size is never scaled\n    return 1\n\n\nmodel.size_model = size_modulation\nmodel.build()\n\nfig, axes = plt.subplots(1, 2, sharex=True, sharey=True)\nimplant.stim = {'A4': BiphasicPulseTrain(20, 1, 0.45)}\npercept = model.predict_percept(implant)\nimplant.stim = {'A4': BiphasicPulseTrain(20, 3, 0.45)}\nnew_percept = model.predict_percept(implant)\nnew_percept.plot(ax=axes[1])\npercept.plot(ax=axes[0], vmax=new_percept.max())\naxes[0].set_title(\"1xTh\")\naxes[1].set_title(\"3xTh\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The stimulus with the larger amplitude creates a brighter, but equally sized, phosphene.\n\n\nAn effect model can also be a class with its own parameters, \nwhich can be shared with the overarching ``BiphasicAxonMapModel`` (e.g., an effect \nmodel can depend on ``rho``, and if ``model.rho`` is changed, ``rho`` will also change in\nthe effect model). For an example of this, \nsee :py:class:`~pulse2percept.models.granley2021.DefaultSizeModel`.\n\n\nIf you use custom effect models with the JAX engine, they must be written for JAX so that they can\nbe JIT-compiled (i.e., using ``jax.numpy`` instead of ``numpy``).\n\n"
      ]
    },
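    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of a class-based effect model, the cell below defines a hypothetical ``LinearSizeModel`` whose size factor grows linearly with amplitude above threshold. Only the ``(freq, amp, pdur)`` call signature comes from the description above; the class name, its ``slope`` parameter, and the linear form are assumptions made for illustration. It also does not show how pulse2percept shares parameters such as ``rho`` with an effect model; see ``DefaultSizeModel`` for that.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\nclass LinearSizeModel:\n    # Hypothetical effect model: F_size = 1 + slope * max(amp - 1, 0)\n    def __init__(self, slope=0.5):\n        self.slope = slope\n\n    def __call__(self, freq, amp, pdur):\n        # freq and pdur are accepted to match the expected signature but unused\n        return 1 + self.slope * np.maximum(amp - 1, 0)\n\n\nsize_model = LinearSizeModel(slope=0.5)\nprint(size_model(20, 1, 0.45), size_model(20, 3, 0.45))\n# Assumed usage, like the function above:\n# model.size_model = LinearSizeModel(slope=0.5)\n# model.build()"
      ]
    },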
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### JAX Engine\n\nThe default computational engine is Cython, but an engine based on \n[jax](https://github.com/google/jax) is also provided. The JAX engine is slightly faster on CPU\nand significantly faster on GPU, at the cost of increased memory usage. The JAX-based model \ncan be used identically to the Cython engine, but it also has some additional features\nand limitations. \n\n.. note ::\n\n    JAX functions are compiled the first time they are called. Thus, the first\n    call to ``predict_percept`` will be slow. Subsequent calls reuse the compiled and\n    optimized function and are much faster.\n\nOne additional feature is the \n``_predict_spatial_jax`` function,\nwhich is a stripped-down, purely functional version of \n``predict_percept`` that operates on\nNumPy arrays. This avoids the overhead of creating p2p stimulus and percept objects\nand, if used correctly, provides an additional speedup. \n\n``_predict_spatial_jax`` takes \nan ``(n_elecs, 3)`` array specifying the frequency, amplitude, and pulse duration on\neach electrode, and two arrays of shape ``(n_elecs,)`` specifying the x and y locations of the\nelectrodes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
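        "# JAX-based model; the first prediction triggers JIT compilation\nmodel = BiphasicAxonMapModel(engine='jax')\nmodel.build()\nimplant = ArgusII()\n# x and y locations of every electrode, in the implant's electrode order\nex = np.array([implant[e].x for e in implant.electrodes])\ney = np.array([implant[e].y for e in implant.electrodes])\n# One row of (frequency, amplitude, pulse duration) per electrode;\n# rows of zeros are unstimulated electrodes. Row 3 is electrode A4\nstim = np.zeros((len(ex), 3))\nstim[3] = [20, 1, 0.45]\npercept = model._predict_spatial_jax(stim, ex, ey)\n# Reshape the flat result onto the model's spatial grid\npercept = np.array(percept).reshape(model.grid.shape)\nplt.imshow(percept, cmap='gray')\nplt.show()"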
      ]
    },
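    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Indexing rows by hand, as in ``stim[3]`` above, relies on knowing the electrode order. The cell below sketches a hypothetical helper, ``stim_dict_to_array`` (not part of pulse2percept), that builds the ``(n_elecs, 3)`` array from a dictionary keyed by electrode name. It assumes the names are passed in the same order used to build ``ex`` and ``ey``, e.g. ``list(implant.electrodes)``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef stim_dict_to_array(electrode_names, stim):\n    # Hypothetical helper: stim maps electrode name -> (freq, amp, pdur).\n    # Rows follow electrode_names; unstimulated electrodes stay all zeros.\n    arr = np.zeros((len(electrode_names), 3))\n    index = {name: i for i, name in enumerate(electrode_names)}\n    for name, params in stim.items():\n        arr[index[name]] = params\n    return arr\n\n\n# Toy electrode list; with the implant above, use list(implant.electrodes)\nnames = ['A1', 'A2', 'A3', 'A4', 'A5']\nprint(stim_dict_to_array(names, {'A4': (20, 1, 0.45)}))"
      ]
    },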
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Another useful feature is the \n``predict_percept_batched`` function. This\napplies ``predict_percept`` to batches of input stimuli using optimized matrix operations. See also \nits faster, stripped-down version ``_predict_spatial_batched``. This \nfunction is only intended for repeatedly simulating batches of percepts. \nSince JAX compiles each function the first time it is used, calling this function only once\nfor a single group of stimuli will be noticeably slower than repeatedly calling \n``predict_percept``. However, splitting a very large set of stimuli into smaller batches and \nusing ``predict_percept_batched`` will be significantly faster than calling ``predict_percept`` on each\nindividual stimulus.\n\nNote that this function consumes a large amount of memory and may not run on systems or \nGPUs with limited memory. \n\n"
      ]
    }
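    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For a very large set of stimuli, splitting the data into fixed-size batches takes only plain NumPy. The cell below sketches a hypothetical ``iter_batches`` generator (not part of pulse2percept) and assumes the stimuli are stacked into an array of shape ``(n_stimuli, n_elecs, 3)``, one ``_predict_spatial_jax``-style array per stimulus. The exact input format expected by ``predict_percept_batched`` and ``_predict_spatial_batched`` is not shown here, so check their API documentation before passing the batches to them.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef iter_batches(stims, batch_size):\n    # Yield consecutive slices of at most batch_size stimuli along axis 0\n    for start in range(0, len(stims), batch_size):\n        yield stims[start:start + batch_size]\n\n\n# Toy data: 1000 stimuli for a 60-electrode implant, each on electrode row 3\nstims = np.zeros((1000, 60, 3))\nstims[:, 3] = [20, 1, 0.45]\nbatches = list(iter_batches(stims, batch_size=256))\nprint(len(batches), batches[-1].shape)"
      ]
    }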
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}