{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Generating a stimulus from an image\n\n*This example shows how to use images as input stimuli for a retinal implant.*\n\nIn addition to built-in stimuli such as\n:py:class:`~pulse2percept.stimuli.BiphasicPulse` and\n:py:class:`~pulse2percept.stimuli.BiphasicPulseTrain`,\nyou can also load conventional images and convert them to stimuli using\n:py:class:`~pulse2percept.stimuli.ImageStimulus`.\n\n## Loading an image\n\nAn image can be loaded as follows:\n\n.. code:: python\n\n    from pulse2percept.stimuli import ImageStimulus\n    stim = ImageStimulus('path-to-image.png')\n\nBy default, each pixel in the image is assigned to an electrode, and its\ngrayscale value is encoded as an amplitude.\nIf the image has more than one channel (e.g., RGB, RGBA), the image is\nflattened so that each pixel/channel pair is assigned to a different electrode.\nYou can specify names for the electrodes, but the number of electrodes must\nmatch the number of flattened pixel/channel values. By default, electrodes are\nlabeled 1...N.\n\nA number of images come pre-installed with pulse2percept, such as the logo of\nthe Bionic Vision Lab (BVL) at UC Santa Barbara:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pulse2percept as p2p\nimport numpy as np\n\nlogo = p2p.stimuli.LogoBVL()\nprint(logo)"
      ]
    },
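    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The flattening described in the introduction can be sketched in plain NumPy.\nThis is a minimal illustration, not the library's implementation: we use a\nhypothetical random 576x720x4 RGBA array in place of the logo, and we assume\nrow-major (C-order) flattening, so that every pixel/channel pair becomes one\nelectrode value:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical stand-in for an RGBA image with gray levels in [0, 1]:\n",
        "rng = np.random.default_rng(seed=42)\n",
        "rgba = rng.random((576, 720, 4))\n",
        "\n",
        "# One value per pixel/channel pair, in row-major (C) order (an assumption):\n",
        "flat = rgba.ravel()\n",
        "print(flat.shape)  # (1658880,) = 576 * 720 * 4 values, one per electrode\n",
        "\n",
        "# Electrode labels 1...N, as described in the introduction:\n",
        "electrodes = np.arange(1, flat.size + 1)\n",
        "print(electrodes[:3], electrodes[-1])"
      ]
    },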
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Inspecting the ``LogoBVL`` object, we can see that gray levels are converted\nto floats in the range [0, 1], and that the original 576x720x4 image is\nflattened so that each pixel can be assigned to an electrode.\n\nWe also notice that ``time=None``, indicating that the stimulus does not have\na time component. Thus we cannot apply temporal models to it directly.\n\nLike any other :py:class:`~pulse2percept.stimuli.Stimulus` object,\n``LogoBVL`` can be assigned to an implant and used in conjunction with a\nphosphene model.\n\n## Preprocessing an image\n\n:py:class:`~pulse2percept.stimuli.ImageStimulus` objects come with a number\nof methods to process an image before it is passed to an implant. We can:\n\n-  :py:meth:`~pulse2percept.stimuli.ImageStimulus.invert` the\n   polarity of the image (applied to all channels except the alpha channel),\n-  convert RGB and RGBA images to grayscale using\n   :py:meth:`~pulse2percept.stimuli.ImageStimulus.rgb2gray`\n   (note that a change in the number of channels also means a change in the\n   number of electrodes),\n-  :py:meth:`~pulse2percept.stimuli.ImageStimulus.resize` the image\n   to a new height x width (optionally using anti-aliasing),\n-  :py:meth:`~pulse2percept.stimuli.ImageStimulus.scale`,\n   :py:meth:`~pulse2percept.stimuli.ImageStimulus.shift`, and\n   :py:meth:`~pulse2percept.stimuli.ImageStimulus.rotate` the image\n   foreground (i.e., anything that's not black),\n-  :py:meth:`~pulse2percept.stimuli.ImageStimulus.trim` any black borders\n   around the image,\n-  :py:meth:`~pulse2percept.stimuli.ImageStimulus.threshold` the image using\n   a number of commonly used techniques (e.g., Otsu's method, adaptive\n   thresholding, ISODATA),\n-  :py:meth:`~pulse2percept.stimuli.ImageStimulus.filter` the image (e.g.,\n   median filter) and extract edges (e.g., Sobel, Scharr, Canny), or\n-  :py:meth:`~pulse2percept.stimuli.ImageStimulus.apply` any input-output\n   function not covered above (it must accept an image as input and return\n   another image of the same size).\n\nCollectively, these methods should support arbitrarily complex image\npreprocessing strategies, including the ones commonly used by implants such\nas Argus II and Alpha-AMS.\n\nLet's look at a concrete example.\nTo get the BVL logo into proper shape, we need to convert the 4-channel RGBA\nimage to grayscale. This can be done with\n:py:meth:`~pulse2percept.stimuli.ImageStimulus.rgb2gray`.\nIn addition, since grayscale values will be mapped to current amplitudes,\nwe may want to :py:meth:`~pulse2percept.stimuli.ImageStimulus.invert` the\nimage so that image edges appear bright on a dark background.\n\nWe can perform both actions in one line, and plot the result side-by-side\nwith the original image:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
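        "import matplotlib.pyplot as plt\n\n# Invert the polarity, then convert the RGBA image to grayscale:\nlogo_gray = logo.invert().rgb2gray()\n\n# Original (left) vs. preprocessed (right):\nfig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 4))\nlogo.plot(ax=ax1)\nlogo_gray.plot(ax=ax2)"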
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As demonstrated above, multiple image processing steps can be performed in\none line. This is possible because each method returns a copy of the\nprocessed image (without altering the original).\n\nThe following example takes the grayscale logo, shrinks it to 75% of its\noriginal size, rotates it by 30 degrees (counter-clockwise), and trims the\nblack border around the image:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "logo_gray.scale(0.75).rotate(30).trim().plot()"
      ]
    },
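    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Method chaining like the one above works because each method returns a new\nobject instead of modifying the original. The pattern can be sketched with a\nhypothetical ``ToyImage`` class (not part of pulse2percept) that wraps a\nNumPy array; its ``brighten`` method is likewise made up for this\nillustration:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "class ToyImage:\n",
        "    # Hypothetical minimal image wrapper (not part of pulse2percept):\n",
        "    # every method returns a new object and leaves the original untouched.\n",
        "\n",
        "    def __init__(self, data):\n",
        "        self.data = np.asarray(data, dtype=float)\n",
        "\n",
        "    def invert(self):\n",
        "        # Flip gray levels in [0, 1]:\n",
        "        return ToyImage(1.0 - self.data)\n",
        "\n",
        "    def brighten(self, factor):\n",
        "        # Multiply gray levels, then clip back into [0, 1]:\n",
        "        return ToyImage(np.clip(self.data * factor, 0, 1))\n",
        "\n",
        "\n",
        "original = ToyImage([[0.0, 0.25], [0.5, 1.0]])\n",
        "processed = original.invert().brighten(2)\n",
        "print(original.data)   # unchanged by the chained calls\n",
        "print(processed.data)"
      ]
    },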
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As mentioned in the introduction above, the\n:py:meth:`~pulse2percept.stimuli.ImageStimulus.filter` method provides\na number of popular techniques to extract edges from the image, such as:\n\n-  ``'sobel'`` to extract edges using the [Sobel operator](https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.sobel),\n-  ``'scharr'`` to extract edges using the [Scharr operator](https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.scharr),\n   and\n-  ``'canny'`` to extract edges using the [Canny algorithm](https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny).\n\nAdditional parameters (e.g., the standard deviation of the Gaussian filter\nfor the Canny algorithm) can be passed as keyword arguments (e.g.,\n``filter('canny', sigma=3)``).\n\nFor example, we can use the Scharr operator as follows:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "logo_edge = logo_gray.filter('scharr')"
      ]
    },
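    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the ``'scharr'`` and ``'canny'`` options compute, we can call the\ncorresponding scikit-image functions directly on a small synthetic image (a\nbright square on a black background). This is only a sketch: we assume that\n``filter`` forwards to these functions, as the links above suggest, but we do\nnot reproduce any pre- or post-processing that pulse2percept may apply:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from skimage.feature import canny\n",
        "from skimage.filters import scharr\n",
        "\n",
        "# Synthetic grayscale image: a bright square on a black background\n",
        "img = np.zeros((64, 64))\n",
        "img[16:48, 16:48] = 1.0\n",
        "\n",
        "# Scharr: gradient magnitude (float image, large along the square's border)\n",
        "edges_scharr = scharr(img)\n",
        "\n",
        "# Canny: binary edge map; sigma sets the Gaussian smoothing,\n",
        "# analogous to filter('canny', sigma=3)\n",
        "edges_canny = canny(img, sigma=3)\n",
        "\n",
        "print(edges_scharr.shape, edges_scharr.dtype)\n",
        "print(edges_canny.shape, edges_canny.dtype, edges_canny.sum() > 0)"
      ]
    },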
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If more advanced image processing methods are required, we can use the\n:py:meth:`~pulse2percept.stimuli.ImageStimulus.apply` method to apply\nliterally any function to the image. The only requirement is that the\nfunction return an image of the same size.\n\nFor example, we can thicken the edges in the image by using a morphological\noperator (i.e., dilation) provided by\n[scikit-image](https://scikit-image.org):\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from skimage.morphology import dilation\nlogo_dilate = logo_edge.apply(dilation)\n\nfig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 4))\n# Edges extracted with the Scharr operator:\nlogo_edge.plot(ax=ax1)\n# Edges thickened with dilation:\nlogo_dilate.plot(ax=ax2)"
      ]
    },
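    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Any function that maps an image to an image of the same size should be usable\nwith ``apply``. The sketch below wraps scikit-image's Otsu threshold (one of\nthe techniques listed for ``threshold`` above) in a helper, ``binarize_otsu``,\nwhose name is hypothetical and chosen only for this illustration. It then\nchecks, on a synthetic image, that both this helper and ``dilation`` preserve\nthe image shape. With pulse2percept, such a helper would be passed the same\nway as ``dilation`` above, e.g. ``logo_edge.apply(binarize_otsu)``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from skimage.filters import threshold_otsu\n",
        "from skimage.morphology import dilation\n",
        "\n",
        "\n",
        "def binarize_otsu(img):\n",
        "    # Hypothetical helper: 1.0 where a pixel exceeds Otsu's threshold, else 0.0\n",
        "    return (img > threshold_otsu(img)).astype(float)\n",
        "\n",
        "\n",
        "# Synthetic image: dim noisy background plus a brighter noisy square\n",
        "rng = np.random.default_rng(seed=0)\n",
        "img = 0.2 * rng.random((64, 64))\n",
        "img[20:44, 20:44] += 0.6\n",
        "\n",
        "for name, func in [('binarize_otsu', binarize_otsu), ('dilation', dilation)]:\n",
        "    out = func(img)\n",
        "    # The requirement for apply: the output has the same shape as the input\n",
        "    assert out.shape == img.shape, name\n",
        "    print(name, out.shape)"
      ]
    },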
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can also save the processed stimulus as an image:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "logo_dilate.save('dilated_logo.png')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Using the image as input to a retinal implant\n\n:py:class:`~pulse2percept.stimuli.ImageStimulus` can be used in\ncombination with any :py:class:`~pulse2percept.implants.ProsthesisSystem`.\nWe just have to resize the image first so that the number of pixels in the\nimage matches the number of electrodes in the implant.\n\nBut let's start from the top. The first two steps are to create a model and\nchoose an implant:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Simulate only what we need (14x14 deg sampled at 0.1 deg):\nmodel = p2p.models.ScoreboardModel(xrange=(-7, 7), yrange=(-7, 7), xystep=0.1)\nmodel.build()\n\nfrom pulse2percept.implants import AlphaAMS\nimplant = AlphaAMS()\n\n# Show the visual field we're simulating (dashed lines) atop the implant:\nmodel.plot()\nimplant.plot()"
      ]
    },
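    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The ``xrange``, ``yrange``, and ``xystep`` arguments define a grid of points in\nthe visual field. The arithmetic can be sketched with NumPy, assuming (for\nillustration only) that both endpoints of each range are included in the grid:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "xrange, yrange, xystep = (-7, 7), (-7, 7), 0.1  # degrees of visual angle\n",
        "\n",
        "# Samples per axis if both endpoints are included (an assumption)\n",
        "nx = int(round((xrange[1] - xrange[0]) / xystep)) + 1\n",
        "ny = int(round((yrange[1] - yrange[0]) / xystep)) + 1\n",
        "x = np.linspace(*xrange, num=nx)\n",
        "y = np.linspace(*yrange, num=ny)\n",
        "xx, yy = np.meshgrid(x, y)\n",
        "\n",
        "print(nx, ny, xx.shape)  # 141 141 (141, 141)"
      ]
    },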
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since :py:class:`~pulse2percept.implants.AlphaAMS` is a 2D electrode grid,\nall we need to do is downscale the image to the size of the grid:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "implant.stim = logo_gray.resize(implant.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This way, the pixels of the image will be assigned to the electrodes in\nrow-by-row order (i.e., we don't need to specify the actual electrode names).\n\n.. note ::\n\n   If the implant is not a proper 2D grid, you will have to manually specify\n   the input to each electrode.\n\n   In the near future, this will be done automatically using an implant's\n   ``preprocess`` method.\n\nThen the implant can be passed to the model's\n:py:meth:`~pulse2percept.models.ScoreboardModel.predict_percept` method:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "percept_gray = model.predict_percept(implant)"
      ]
    },
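    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The resize-then-assign step used above can be sketched with scikit-image and\nNumPy. We assume a hypothetical 10x10 electrode grid (not the actual Alpha-AMS\ngeometry) and that pixels are read out in row-major (C) order, matching the\nrow-by-row assignment described above:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from skimage.transform import resize\n",
        "\n",
        "# Hypothetical grayscale image and electrode grid shape\n",
        "rng = np.random.default_rng(seed=1)\n",
        "img = rng.random((576, 720))\n",
        "grid_shape = (10, 10)\n",
        "\n",
        "# Downscale so that there is exactly one pixel per electrode\n",
        "small = resize(img, grid_shape, anti_aliasing=True)\n",
        "\n",
        "# Row-by-row readout: the first 10 values belong to the first electrode row\n",
        "amps = small.ravel(order='C')\n",
        "assert amps.size == grid_shape[0] * grid_shape[1]\n",
        "assert np.allclose(amps[:10], small[0])\n",
        "print(small.shape, amps.size)"
      ]
    },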
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        ".. note ::\n\n    Because neither :py:class:`~pulse2percept.stimuli.ImageStimulus` nor\n    :py:class:`~pulse2percept.models.ScoreboardModel` can handle time, the\n    resulting percept will consist of a single image/frame.\n\nTo see what difference our image preprocessing makes on the quality of the\nresulting percept, we can re-run the model on ``logo_dilate`` and plot the\ntwo percepts side-by-side:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "implant.stim = logo_dilate.trim().resize(implant.shape)\npercept_dilate = model.predict_percept(implant)\n\nfig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4))\npercept_gray.plot(ax=ax1)\npercept_dilate.plot(ax=ax2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Converting the image to a series of electrical pulses\n\n:py:class:`~pulse2percept.stimuli.ImageStimulus` has an\n:py:meth:`~pulse2percept.stimuli.ImageStimulus.encode` method\nto convert an image into a series of pulse trains (i.e., into electrical\nstimuli with a time component).\n\nBy default, the ``encode`` method will interpret the gray level of a pixel as\nthe current amplitude of a :py:class:`~pulse2percept.stimuli.BiphasicPulse`\nwith 0.46ms phase duration (500ms total stimulus duration). Gray levels in\nthe range [0, 1] will be mapped onto currents in the range [0, 50] uA:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "implant.stim = logo_dilate.trim().resize(implant.shape).encode()"
      ]
    },
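    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The gray-level-to-current mapping described above can be sketched as a linear\nrescaling. We assume here that ``encode`` maps gray levels linearly from\n[0, 1] onto the amplitude range (in uA); the helper ``gray_to_amp`` is a\nhypothetical name used only for this illustration:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def gray_to_amp(gray, amp_range=(0, 50)):\n",
        "    # Hypothetical helper: linear map from [0, 1] onto amp_range (uA)\n",
        "    gray = np.asarray(gray, dtype=float)\n",
        "    return amp_range[0] + gray * (amp_range[1] - amp_range[0])\n",
        "\n",
        "\n",
        "print(gray_to_amp([0.0, 0.5, 1.0]))                     # [ 0. 25. 50.]\n",
        "print(gray_to_amp([0.0, 0.5, 1.0], amp_range=(0, 20)))  # [ 0. 10. 20.]"
      ]
    },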
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can customize the range of amplitudes by passing a keyword\nargument; e.g., ``amp_range=(0, 20)`` to use currents in [0, 20] uA.\n\nWe can also specify our own pulse / pulse train. First, we need to\ncreate the pulse we want to use (with an amplitude of 1 uA). Then, we need to\npass it as an additional keyword argument; e.g.,\n``pulse=BiphasicPulseTrain(10, 1, 0.2, stim_dur=200)`` to use a 10 Hz\nbiphasic pulse train (0.2 ms phase duration, overall duration 200 ms).\n\n## Using the image as input to a spatiotemporal model\n\nNow, if we passed the new stimulus to\n:py:class:`~pulse2percept.models.ScoreboardModel`, it would simply apply the\nmodel (in space) to every time point in the stimulus.\nTo get a proper temporal response, we need to combine the scoreboard model\nwith a temporal model, such as\n:py:class:`~pulse2percept.models.Horsager2009Temporal`:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = p2p.models.Model(spatial=p2p.models.ScoreboardSpatial,\n                         temporal=p2p.models.Horsager2009Temporal)"
      ]
    },
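    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build intuition for the pulse-train option mentioned above, the sketch\nbelow samples a 10 Hz biphasic pulse train (1 uA amplitude, 0.2 ms phase\nduration, 200 ms total) as a NumPy array. It assumes cathodic-first phases,\nno interphase gap, and a 0.01 ms sampling step. It is an illustration only\nand does not reproduce how\n:py:class:`~pulse2percept.stimuli.BiphasicPulseTrain` discretizes time:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "freq_hz, amp_ua, phase_dur_ms, stim_dur_ms = 10, 1.0, 0.2, 200\n",
        "dt_ms = 0.01  # sampling step (an assumption)\n",
        "\n",
        "# Work in integer sample counts to avoid floating-point boundary effects\n",
        "n_samples = int(round(stim_dur_ms / dt_ms))             # 20000 samples\n",
        "phase_samples = int(round(phase_dur_ms / dt_ms))        # 20 samples per phase\n",
        "period_samples = int(round(1000.0 / freq_hz / dt_ms))   # 100 ms between onsets\n",
        "\n",
        "signal = np.zeros(n_samples)\n",
        "for onset in range(0, n_samples, period_samples):\n",
        "    signal[onset:onset + phase_samples] = -amp_ua                     # cathodic phase\n",
        "    signal[onset + phase_samples:onset + 2 * phase_samples] = amp_ua  # anodic phase\n",
        "t = np.arange(n_samples) * dt_ms  # time axis in ms, e.g. for plotting\n",
        "\n",
        "# Two pulses in 200 ms; the phases cancel, so the train is charge-balanced\n",
        "n_pulses = len(range(0, n_samples, period_samples))\n",
        "print(n_pulses, signal.sum())  # 2 0.0"
      ]
    },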
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-info\"><h4>Note</h4><p>You can combine any spatial model (names ending in **Spatial**) with any\n   temporal model (names ending in **Temporal**).</p></div>\n\nTo make the model focus on the same visual field as above, we set ``xrange``,\n``yrange``, and choose a proper ``xystep``.\n\nThe ``rho`` parameter of the scoreboard model controls how much blur we get\nin the resulting percept. The value of this parameter should be set\nempirically to match the quality of the vision reported behaviorally by each\nimplant user.\nFor the purpose of this tutorial, we will set it to 50um:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model.build(xrange=(-7, 7), yrange=(-7, 7), xystep=0.1, rho=50)"
      ]
    },
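    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how ``rho`` controls blur, the sketch below renders a single electrode\nas a blob whose brightness falls off as exp(-d^2 / (2 rho^2)) with distance d\n(in microns) from the electrode. The Gaussian form, the grid extent, and the\nhelper name ``blob`` are our assumptions for this illustration and are not\ntaken from the scoreboard model's source:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def blob(rho_um, extent_um=500, step_um=10):\n",
        "    # Hypothetical helper: brightness of one electrode (at the origin)\n",
        "    # on a square retinal grid, assuming a Gaussian falloff\n",
        "    coords = np.arange(-extent_um, extent_um + step_um, step_um)\n",
        "    xx, yy = np.meshgrid(coords, coords)\n",
        "    return np.exp(-(xx ** 2 + yy ** 2) / (2.0 * rho_um ** 2))\n",
        "\n",
        "\n",
        "for rho in (25, 50, 100):\n",
        "    b = blob(rho)\n",
        "    # Fraction of grid points brighter than half maximum grows with rho\n",
        "    print(rho, round(float((b > 0.5).mean()), 4))"
      ]
    },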
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The predicted percept will now be a movie, where the spatial response (i.e.,\neach frame of the movie) is primarily determined by the scoreboard model, but\nthe temporal evolution of these frames is determined by the Horsager model.\n\nBy default, the model will output a movie frame every 20 ms (corresponding to\na 50 Hz frame rate). The frame rate can be adjusted by passing a list of\ntime points to :py:meth:`~pulse2percept.models.Model.predict_percept` (e.g.,\n``t_percept=np.arange(500)`` to get an output every millisecond):\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "percept = model.predict_percept(implant)"
      ]
    },
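    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The frame-rate arithmetic above can be checked with NumPy. This is only a\nsketch: we assume a 500 ms stimulus and do not claim that these arrays are\nidentical to the time points the model uses by default. An array such as\n``t_every_ms`` is what would be passed as\n``model.predict_percept(implant, t_percept=t_every_ms)``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "stim_dur_ms = 500\n",
        "frame_step_ms = 20\n",
        "\n",
        "t_default_like = np.arange(0, stim_dur_ms, frame_step_ms)  # one frame every 20 ms\n",
        "t_every_ms = np.arange(stim_dur_ms)                        # as in t_percept=np.arange(500)\n",
        "\n",
        "print(1000 / frame_step_ms)                  # 50.0 frames per second\n",
        "print(t_default_like.size, t_every_ms.size)  # 25 500"
      ]
    },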
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The output of the model is a :py:class:`~pulse2percept.percepts.Percept`\nobject, which can be animated in IPython or Jupyter Notebook using the\n:py:meth:`~pulse2percept.percepts.Percept.play` method:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "percept.play()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can also save the percept as a movie:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "percept.save('logo_percept.mp4')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}