Source code for wizard._exploration.surface

"""
surface.py
==========

.. module:: surface
:platform: Unix
:synopsis: Surface plotting and manipulation module for the hsi-wizard package.

Module Overview
--------------

This module provides functionalities for manipulating and visualizing data cubes.
It includes utilities for slicing, cutting, and plotting 3D surfaces interactively
using sliders.

"""

from wizard import DataCube
import numpy as np
import copy
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.widgets import Slider


def dc_cut_by_value(z: np.array, val: int, type: str) -> np.array:
    """
    Threshold a 2D data slice by a normalized cut-off value.

    This function normalizes the input array to its maximum value, applies a threshold
    such that all values less than or equal to the cut-off are replaced with the array's minimum,
    and preserves values above the threshold. Use this to mask out low-intensity regions
    in a hyperspectral data slice.

    Parameters
    ----------
    z : np.ndarray
        2D array (spatial slice) extracted from a DataCube (shape: (x, y)).
    val : float
        Normalized threshold between 0.0 and 1.0. Values less than or equal to this threshold
        are set to the minimum value of the normalized slice.
    type : str, optional
        Label for the type of cut applied; currently unused but reserved for future modes.

    Returns
    -------
    np.ndarray
        A copy of the input array after normalization and threshold cut.


    Notes
    -----
    - The input array is deeply copied to avoid modifying original data.
    - Thresholding is performed after normalizing by the array's maximum.

    Examples
    --------
    >>> slice2D = np.array([[0, 5], [10, 15]])
    >>> dc_cut_by_value(slice2D, 0.5)
    array([[0. , 0. ], [0.66666667, 1. ]])
    """
    new_z = copy.deepcopy(z)
    new_z /= new_z.max()
    new_z[new_z <= val] = new_z.min()
    return new_z


def get_z_surface(cube: np.array, v: int) -> np.array:
    """
    Extract the positive-valued surface from a spectral slice of a data cube.

    This function selects the v-th spectral band from a 3D hyperspectral cube,
    masks out non-positive values, and constructs a 2D surface array for plotting.
    Use this to visualize spatial distribution at a given wavelength index.

    Parameters
    ----------
    cube : np.ndarray
        3D data cube with shape (v, x, y), where v is the number of spectral bands.
    v : int
        Index of the spectral band to extract (0 <= v < cube.shape[0]).

    Returns
    -------
    np.ndarray
        2D array (shape: (x, y)) containing only the positive values from the slice;
        non-positive entries are left at zero.


    Examples
    --------
    >>> cube = np.zeros((3, 2, 2))
    >>> cube[1] = [[1, -1], [0, 2]]
    >>> get_z_surface(cube, 1)
    array([[1., 0.], [0., 2.]])
    """
    z = np.zeros((cube.shape[1], cube.shape[2]))
    slice_v = cube[v, :, :]
    mask = slice_v > 0
    z[mask] = slice_v[mask]
    return z


[docs] def plot_surface(dc: DataCube, index: int = 0): """ Create an interactive 3D surface plot from a DataCube slice. This function visualizes a DataCube by plotting a 3D surface of a selected spectral band. Users can manipulate two sliders: one to change the wavelength index (spectral band) and another to adjust the normalized cut-off threshold, updating the plot in real time. Parameters ---------- dc : DataCube DataCube instance index : int, optional Initial spectral index to display (default is 0). Returns ------- None Notes ----- - The wavelength slider is labeled with actual wavelength values at evenly spaced ticks. - The cut-off slider normalizes data slices between 0 and 1 before thresholding. - Use the sliders to explore spectral variation and mask out low-intensity regions. Examples -------- >>> import wizard >>> dc = wizard.DataCube(np.random.rand(10, 100, 100), wavelengths=list(np.linspace(400, 700, 10))) >>> wizard.plot_surface(dc, index=5) """ def update(val): idx = int(slider.val) # Ensure integer values cut_val = slider_cut.val # Get cut value ax.clear() z = get_z_surface(dc, idx) # Apply the cut before getting the surface data z = dc_cut_by_value(z, cut_val, type="") x, y = np.meshgrid(range(dc.shape[1]), range(dc.shape[2])) ax.plot_surface(x, y, z.T, cmap=cm.coolwarm) ax.set_title(f'{dc.name if dc.name else ""} @{dc.wavelengths[idx]:.2f} {dc.notation if dc.notation else ""}') ax.set(xlabel='x', ylabel='y', zlabel='counts') fig.canvas.draw_idle() fig = plt.figure() ax = fig.add_subplot(111, projection='3d') # Wavelength slider slider_ax = fig.add_axes([0.2, 0.02, 0.6, 0.03], facecolor='lightgoldenrodyellow') tick_positions = np.linspace(0, len(dc.wavelengths) - 1, min(5, len(dc.wavelengths))).astype(int) slider = Slider(slider_ax, 'Wavelength', 0, dc.shape[0] - 1, valinit=index, valstep=1) slider.ax.set_xticks(tick_positions) slider.ax.set_xticklabels([f'{dc.wavelengths[i]:.2f}' for i in tick_positions]) slider.on_changed(update) # Cut value slider slider_cut_ax = fig.add_axes([0.2, 0.06, 0.6, 0.03], facecolor='lightgoldenrodyellow') slider_cut = Slider(slider_cut_ax, 'Cut Value', 0, 1, valinit=0, valstep=0.01) # Cut value from 0 to 1 slider_cut.on_changed(update) update(index) # Initial plot plt.show()