| import numpy as np |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| from ssms.config import get_model_config |
| from ssms.basic_simulators import Simulator |
|
|
# Model configuration table from ssms, keyed by model name. Functions below read
# each model's positional "params" order and its "nchoices" from it.
_model_config = get_model_config()
| from matplotlib.lines import Line2D |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
# Per-model overrides applied to the noise-free "cartoon" simulation only:
# each trial-by-trial variability parameter is replaced by the value listed here
# so the cartoon traces a single representative path.
_VARIABILITY_COLLAPSE = dict(
    ddm_sdv=dict(sv=0.0),
    ddm_st=dict(st=0.0),
    # NOTE(review): 1e-9 rather than 0.0 -- presumably the truncated-normal NDT
    # draw does not accept a zero spread; confirm against ssms.
    ddm_truncnormt=dict(st=1e-9),
    ddm_rayleight=dict(st=0.0),
)
|
|
| |
| |
def _rayleigh_ndt_mean(theta_d):
    """Expected value of a Rayleigh(scale=theta_d["st"]) NDT component."""
    return theta_d["st"] * np.sqrt(np.pi / 2.0)


# Models whose collapsed cartoon NDT (variability param set to 0) differs from
# the expected NDT; each entry maps the parameter dict to the shift to add to t.
_EXPECTED_T_POST_SHIFT = {"ddm_rayleight": _rayleigh_ndt_mean}
|
|
|
|
def _collapse_theta_for_cartoon(model_name, theta):
    """Build the parameter set used for the noise-free "cartoon" simulation.

    ``theta`` is the list-of-lists layout handed to ``Simulator.simulate``: its
    first row is positional in ``_model_config[model_name]["params"]`` order.
    For models listed in ``_VARIABILITY_COLLAPSE`` a new single-row list is
    returned in which each listed variability parameter is replaced by its
    collapse value (only ``theta[0]`` is carried over). Any other model gets
    ``theta`` itself back, untouched.
    """
    replacements = _VARIABILITY_COLLAPSE.get(model_name)
    if not replacements:
        return theta
    param_order = _model_config[model_name]["params"]
    collapsed_row = list(theta[0])
    for param_name in replacements:
        collapsed_row[param_order.index(param_name)] = replacements[param_name]
    return [collapsed_row]
|
|
|
|
def _apply_expected_t_shift(model_name, theta_dict, sim_out):
    """Add the expected-NDT correction to ``sim_out['metadata']['t']`` when one is registered.

    Needed for NDT distributions whose "variability param = 0" collapse does not
    match the expectation (currently only ddm_rayleight: Rayleigh(0) = 0 while
    E[Rayleigh(scale=st)] = st * sqrt(pi / 2)). The metadata dict is updated in
    place; the stored ``t`` keeps its original dtype. Models without an entry in
    ``_EXPECTED_T_POST_SHIFT`` are left alone.
    """
    shift_fn = _EXPECTED_T_POST_SHIFT.get(model_name)
    if shift_fn is not None:
        offset = shift_fn(theta_dict)
        metadata = sim_out["metadata"]
        base_t = np.asarray(metadata["t"])
        metadata["t"] = (base_t + offset).astype(base_t.dtype, copy=False)
|
|
|
|
def mean_rt_per_choice(sim_output):
    """Average reaction time for each observed choice, ignoring -999 sentinels.

    Samples whose RT or choice equals -999 (deadline timeouts / invalid
    responses) are dropped before averaging. Keys are the distinct remaining
    choice values in ascending order. They are converted to ``int`` when they
    are whole numbers and to ``float`` otherwise, so they line up with DataFrame
    column names. Values are plain Python floats.

    An empty dict is returned when no sample survives the filter.
    """
    flat_rts = np.asarray(sim_output["rts"]).flatten()
    flat_choices = np.asarray(sim_output["choices"]).flatten()
    keep = np.logical_and(flat_rts != -999, flat_choices != -999.0)
    if not np.any(keep):
        return {}
    kept_rts = flat_rts[keep]
    kept_choices = flat_choices[keep]

    def _as_key(value):
        as_float = float(value)
        return int(value) if as_float.is_integer() else as_float

    return {
        _as_key(choice): float(kept_rts[kept_choices == choice].mean())
        for choice in np.unique(kept_choices)
    }
|
|
|
|
| |
| |
def _patch_trajectory_t_with_actual_ndt(sim_out, delta_t=0.001):
    """Override ``sim_out['metadata']['t']`` with the NDT actually used for this single-trial draw.

    The simulator records the *input* ``t`` in metadata, not the per-sample draw
    from ``t_dist`` (e.g. Uniform/TruncNormal/Rayleigh for ddm_st /
    ddm_truncnormt / ddm_rayleight). Without this patch every trajectory plotter
    starts at the input ``t`` and ignores NDT variability. The per-sample NDT is
    back-derived as ``RT - decision_time``, where
    ``decision_time = (last valid trajectory step) * delta_t``. Trajectory calls
    in this module use ``smooth_unif=False``, so ``RT = decision_time + NDT``.

    Arguments:
        sim_out: dict
            ``Simulator.simulate`` output for ``n_samples=1``. Modified in place.
        delta_t: float <default=0.001>
            Step size the trajectory was simulated with.

    Notes:
        * A trajectory with several columns (one per accumulator) is scanned
          row-wise. The decision step is the last row in which any column is
          valid (> -999). Flattening would interleave the columns and inflate
          the decision time.
        * If the trial's RT is the -999 timeout/invalid sentinel, no NDT can be
          recovered and the metadata is left unchanged.
        * Effectively a no-op for models without NDT variability (the recovered
          NDT equals the input ``t``).
    """
    rt = float(np.asarray(sim_out["rts"]).flat[0])
    if rt == -999:
        # Timed-out / invalid trial: RT - decision_time would be meaningless.
        return

    traj = np.atleast_1d(np.asarray(sim_out["metadata"]["trajectory"]))
    if traj.size:
        # View as (time_steps, n_columns) so the index found is a time-step index.
        step_is_valid = (traj.reshape(traj.shape[0], -1) > -999).any(axis=1)
        valid_steps = np.flatnonzero(step_is_valid)
    else:
        valid_steps = np.empty(0, dtype=int)

    decision_time = float(valid_steps[-1] * delta_t) if valid_steps.size else 0.0
    actual_ndt = max(0.0, rt - decision_time)
    t_arr = np.asarray(sim_out["metadata"]["t"])
    sim_out["metadata"]["t"] = np.full_like(t_arr, actual_ndt, dtype=t_arr.dtype)
|
|
|
|
def plot_func_model(
    model_name,
    theta,
    axis,
    value_range=None,
    n_samples=10,
    bin_size=0.05,
    add_data_rts=True,
    add_data_model_keep_slope=True,
    add_data_model_keep_boundary=True,
    add_data_model_keep_ndt=True,
    add_data_model_keep_starting_point=True,
    add_data_model_markersize_starting_point=50,
    add_data_model_markertype_starting_point=0,
    add_data_model_markershift_starting_point=0,
    n_trajectories=0,
    linewidth_histogram=0.5,
    linewidth_model=0.5,
    legend_fontsize=12,
    legend_shadow=True,
    legend_location="upper right",
    data_color="blue",
    posterior_uncertainty_color="black",
    alpha=0.05,
    delta_t_model=0.001,
    random_state=None,
    add_legend=True,
    expected_random_params=True,
    **kwargs,
):
    """Plot simulated RT histograms and a noise-free model cartoon for a 2-choice model.

    Simulates ``n_samples`` trials of ``model_name`` at ``theta`` and draws the
    RT histogram of choice == 1 above the upper boundary and that of every other
    choice below the lower boundary. Both are scaled so that together they
    integrate to one over the valid (non -999) trials. On top it draws a cartoon
    of the model (boundaries, noise-free path, non-decision time, starting
    point). Optionally it overlays ``n_trajectories`` noisy single-trial
    trajectories.

    Arguments:
        model_name: str
            Model key understood by ssms (``get_model_config`` / ``Simulator``).
        theta: list of lists
            Parameters forwarded to ``Simulator.simulate``. The cartoon helpers
            read ``theta[0]`` positionally in the model's ``params`` order.
        axis: matplotlib.axes.Axes
            Axis to plot into.
        value_range: sequence of float
            RT range; only the first and last entries are used (x-limits and
            histogram bin range). Required.

    Optional:
        n_samples: int <default=10>
            Number of simulated trials used for the histograms.
        bin_size: float <default=0.05>
            Histogram bin width.
        add_data_rts: bool <default=True>
            Draw the simulated RT histograms.
        add_data_model_keep_slope / _keep_boundary / _keep_ndt / _keep_starting_point: bool
            Which cartoon elements to draw.
        add_data_model_markersize_starting_point / _markertype_starting_point /
        _markershift_starting_point:
            Size, matplotlib marker and x-shift of the starting-point marker.
        n_trajectories: int <default=0>
            Number of noisy single-trial trajectories to overlay.
        linewidth_histogram: float <default=0.5>
            Line width of histogram elements.
        linewidth_model: float <default=0.5>
            Line width of cartoon elements.
        legend_fontsize, legend_shadow, legend_location, add_legend:
            Accepted for interface compatibility; this function draws no legend.
        data_color: str <default="blue">
            Histogram color.
        posterior_uncertainty_color: str <default="black">
            Cartoon color.
        alpha: float <default=0.05>
            Histogram transparency (also forwarded to the cartoon helper).
        delta_t_model: float <default=0.001>
            Step of the time grid used to draw the cartoon and trajectories.
        random_state: int or None
            If given, seeds NumPy's global RNG (``np.random.seed``) before the
            per-simulation seeds are drawn.
        expected_random_params: bool <default=True>
            Collapse trial-by-trial variability parameters for the cartoon and
            apply the registered expected-NDT shift, if any.
        **kwargs:
            ``ylim`` (default 3) and ``hist_histtype`` (default "step") are
            consumed here. ``ylim_high``/``ylim_low`` (both needed to override
            ``+/-ylim``) and ``hist_bottom_high``/``hist_bottom_low`` (both needed
            to override the boundary-based baselines) are read here. Every
            remaining key, including those four, is forwarded to
            ``_add_trajectories``.

    Returns:
        The ``axis`` that was drawn into.

    Raises:
        NotImplementedError: if ``value_range`` is None.
        ValueError: if the model has more than two choices.
    """
    if value_range is None:
        raise NotImplementedError("value_range keyword argument must be supplied.")
    # Only the end points of value_range matter.
    if len(value_range) > 2:
        value_range = (value_range[0], value_range[-1])
    bins = np.arange(value_range[0], value_range[-1], bin_size)

    if _model_config[model_name]["nchoices"] > 2:
        raise ValueError("The model plot works only for 2 choice models at the moment")

    # Simulator resolution. delta_t_model only controls the plotting grid.
    sim_delta_t = 0.001

    if random_state is not None:
        np.random.seed(random_state)
    sim = Simulator(model=model_name)

    # Each simulate() call gets its own seed drawn from the global RNG.
    rand_int = np.random.choice(400000000)
    sim_out = sim.simulate(
        theta=theta,
        n_samples=n_samples,
        no_noise=False,
        delta_t=sim_delta_t,
        random_state=rand_int,
    )

    sim_out_traj = {}
    for i in range(n_trajectories):
        rand_int = np.random.choice(400000000)
        sim_out_traj[i] = sim.simulate(
            theta=theta,
            n_samples=1,
            no_noise=False,
            delta_t=sim_delta_t,
            random_state=rand_int,
            smooth_unif=False,
        )
        _patch_trajectory_t_with_actual_ndt(sim_out_traj[i], delta_t=sim_delta_t)

    theta_cartoon = (
        _collapse_theta_for_cartoon(model_name, theta) if expected_random_params else theta
    )
    sim_out_no_noise = sim.simulate(
        theta=theta_cartoon,
        n_samples=1,
        no_noise=True,
        delta_t=sim_delta_t,
        smooth_unif=False,
    )
    if expected_random_params:
        # The shift uses the caller's (uncollapsed) parameters.
        params = _model_config[model_name]["params"]
        theta_dict = dict(zip(params, theta[0]))
        _apply_expected_t_shift(model_name, theta_dict, sim_out_no_noise)

    rts = sim_out["rts"]
    choices = sim_out["choices"]
    valid = rts != -999
    up_mask = valid & (choices == 1)
    down_mask = valid & (choices != 1)
    n_valid = rts[valid].shape[0]

    boundary = sim_out["metadata"]["boundary"]
    b_high = np.maximum(boundary, 0)
    b_low = np.minimum((-1) * boundary, 0)

    ylim = kwargs.pop("ylim", 3)
    hist_histtype = kwargs.pop("hist_histtype", "step")

    if ("ylim_high" in kwargs) and ("ylim_low" in kwargs):
        ylim_high = kwargs["ylim_high"]
        ylim_low = kwargs["ylim_low"]
    else:
        ylim_high = ylim
        ylim_low = -ylim

    if ("hist_bottom_high" in kwargs) and ("hist_bottom_low" in kwargs):
        hist_bottom_high = kwargs["hist_bottom_high"]
        hist_bottom_low = kwargs["hist_bottom_low"]
    else:
        hist_bottom_high = b_high[0]
        hist_bottom_low = -b_low[0]

    axis.set_xlim(value_range[0], value_range[-1])
    axis.set_ylim(ylim_low, ylim_high)
    axis_twin_up = axis.twinx()
    axis_twin_down = axis.twinx()
    axis_twin_up.set_ylim(ylim_low, ylim_high)
    axis_twin_up.set_yticks([])
    # Inverted limits, so the "down" histogram grows downwards from the lower boundary.
    axis_twin_down.set_ylim(ylim_high, ylim_low)
    axis_twin_down.set_yticks([])
    axis_twin_down.set_axis_off()
    axis_twin_up.set_axis_off()

    # Skip the histograms when disabled, or when every trial timed out (the
    # density normalization would otherwise divide by zero).
    if add_data_rts and n_valid > 0:
        density_per_sample = (1 / bin_size) / n_valid
        for twin, mask, bottom in (
            (axis_twin_up, up_mask, hist_bottom_high),
            (axis_twin_down, down_mask, hist_bottom_low),
        ):
            choice_rts = rts[mask]
            twin.hist(
                np.abs(choice_rts),
                bins=bins,
                weights=np.tile(density_per_sample, reps=choice_rts.shape[0]),
                histtype=hist_histtype,
                bottom=bottom,
                alpha=alpha,
                color=data_color,
                edgecolor=data_color,
                linewidth=linewidth_histogram,
                zorder=-1,
            )

    t_s = np.arange(0, sim_out["metadata"]["max_t"], delta_t_model)

    _add_model_cartoon_to_ax(
        sample=sim_out_no_noise,
        axis=axis,
        keep_slope=add_data_model_keep_slope,
        keep_boundary=add_data_model_keep_boundary,
        keep_ndt=add_data_model_keep_ndt,
        keep_starting_point=add_data_model_keep_starting_point,
        markersize_starting_point=add_data_model_markersize_starting_point,
        markertype_starting_point=add_data_model_markertype_starting_point,
        markershift_starting_point=add_data_model_markershift_starting_point,
        delta_t_graph=delta_t_model,
        sample_hist_alpha=alpha,
        lw_m=linewidth_model,
        ylim_low=ylim_low,
        ylim_high=ylim_high,
        t_s=t_s,
        color=posterior_uncertainty_color,
        zorder_cnt=0,
    )

    if n_trajectories > 0:
        _add_trajectories(
            axis=axis,
            sample=sim_out_traj,
            t_s=t_s,
            delta_t_graph=delta_t_model,
            n_trajectories=n_trajectories,
            **kwargs,
        )

    return axis
|
|
| |
def _add_trajectories(
    axis=None,
    sample=None,
    t_s=None,
    delta_t_graph=0.01,
    n_trajectories=10,
    supplied_trajectory=None,
    maxid_supplied_trajectory=1,
    highlight_trajectory_rt_choice=True,
    markersize_trajectory_rt_choice=50,
    markertype_trajectory_rt_choice="*",
    markercolor_trajectory_rt_choice="red",
    linewidth_trajectories=1,
    alpha_trajectories=0.5,
    color_trajectories="black",
    **kwargs,
):
    """Add noisy single-trial trajectories of a 2-choice model to ``axis``.

    Arguments:
        axis: matplotlib.axes.Axes
            Axis to draw into.
        sample: dict
            Maps ``0..n_trajectories-1`` to ``Simulator.simulate`` outputs with
            ``n_samples=1``. Each ``metadata['t'][0]`` is used as that
            trajectory's x-offset. ``sample[0]`` also supplies the boundary and
            ``possible_choices``.
        t_s: numpy.ndarray
            Time grid the trajectory steps are drawn against.
        delta_t_graph: accepted for interface compatibility; not needed here.
        n_trajectories: int
            How many entries of ``sample`` to draw.
        supplied_trajectory, maxid_supplied_trajectory: accepted but unused.
        highlight_trajectory_rt_choice: bool
            Mark where each trajectory meets the boundary of its chosen option.
        markersize_trajectory_rt_choice, markertype_trajectory_rt_choice:
            Size and matplotlib marker of that highlight.
        markercolor_trajectory_rt_choice, color_trajectories:
            Highlight / line colors. Each may be a dict keyed by choice, a list
            matched positionally to ``possible_choices``, or a single color
            applied to every choice.
        linewidth_trajectories, alpha_trajectories: line style.
        **kwargs: ignored, so callers can forward extra plotting options.
    """

    def _per_choice(spec):
        # Expand a color spec into {choice: color}.
        if isinstance(spec, dict):
            return spec
        possible_choices = sample[0]["metadata"]["possible_choices"]
        if isinstance(spec, list):
            return {choice: spec[idx] for idx, choice in enumerate(possible_choices)}
        return {choice: spec for choice in possible_choices}

    marker_colors = _per_choice(markercolor_trajectory_rt_choice)
    line_colors = _per_choice(color_trajectories)

    # Upper (>= 0) and lower (<= 0) boundaries of the first sample. Indexing
    # these at the decision step is what the former roll-then-offset lookup
    # computed, minus its out-of-range failure for late decisions.
    boundary = sample[0]["metadata"]["boundary"]
    upper = np.maximum(boundary, 0)
    lower = np.minimum((-1) * boundary, 0)
    last_boundary_idx = upper.shape[0] - 1
    last_grid_idx = t_s.shape[0] - 1

    for i in range(n_trajectories):
        traj = sample[i]["metadata"]["trajectory"]
        choice = sample[i]["choices"].flatten().item()
        t_offset = sample[i]["metadata"]["t"][0]
        # Number of non-padding (> -999) samples minus one, capped to the grid.
        # This is the final step index, assuming a single-column trajectory whose
        # valid samples form a prefix (TODO confirm for all 2-choice models).
        maxid = min(max(np.count_nonzero(np.asarray(traj) > -999) - 1, 0), t_s.shape[0])

        axis.plot(
            t_s[:maxid] + t_offset,
            traj[:maxid],
            color=line_colors[choice],
            alpha=alpha_trajectories,
            linewidth=linewidth_trajectories,
            zorder=2000 + i,
        )

        if highlight_trajectory_rt_choice:
            b_idx = min(maxid, last_boundary_idx)
            hit_level = upper[b_idx] if choice > 0 else lower[b_idx]
            axis.scatter(
                t_s[min(maxid, last_grid_idx)] + t_offset,
                hit_level,
                markersize_trajectory_rt_choice,
                color=marker_colors[choice],
                alpha=1,
                marker=markertype_trajectory_rt_choice,
                zorder=2000 + i,
            )
|
|
| |
def _add_model_cartoon_to_ax(
    sample=None,
    axis=None,
    keep_slope=True,
    keep_boundary=True,
    keep_ndt=True,
    keep_starting_point=True,
    markersize_starting_point=80,
    markertype_starting_point=1,
    markershift_starting_point=-0.05,
    delta_t_graph=None,
    sample_hist_alpha=None,
    lw_m=None,
    tmp_label=None,
    ylim_low=None,
    ylim_high=None,
    t_s=None,
    zorder_cnt=1,
    color="black",
):
    """Draw the noise-free cartoon of a 2-choice model onto ``axis``.

    ``sample`` is a ``Simulator.simulate`` output. Its ``metadata`` entries
    ``boundary``, ``t``, ``trajectory`` and ``z`` are read. The upper boundary
    is ``max(boundary, 0)`` and the lower one ``min(-boundary, 0)``. Both are
    delayed by ``int(t / delta_t_graph + 1)`` samples, holding their initial
    value over the gap, and are then drawn against ``t_s``.

    Each ``keep_*`` flag toggles one element:
        * ``keep_boundary``: the two boundaries.
        * ``keep_slope``: the noise-free path, offset by ``t``.
        * ``keep_ndt``: a dashed vertical line at ``t``.
        * ``keep_starting_point``: a marker at fraction ``z`` between the boundaries.

    ``ylim_low``/``ylim_high`` are handed to ``Axes.axvline`` as ``ymin``/``ymax``,
    which matplotlib interprets as axes fractions rather than data values.
    ``sample_hist_alpha`` is accepted but not used.
    """
    meta = sample["metadata"]
    boundary = meta["boundary"]
    n_shift = int(np.asarray(meta["t"]).flat[0] / delta_t_graph + 1)

    def _delayed(curve):
        # Shift right by n_shift samples and hold the initial value in front.
        head = curve[0]
        out = np.roll(curve, n_shift)
        out[:n_shift] = head
        return out

    upper = _delayed(np.maximum(boundary, 0))
    lower = _delayed(np.minimum((-1) * boundary, 0))

    trajectory = meta["trajectory"]
    n_grid = t_s.shape[0]
    maxid = np.minimum(np.argmax(np.where(trajectory > -999)), n_grid)
    z_level = 1000 + zorder_cnt

    if keep_boundary:
        axis.plot(
            t_s,
            upper[:n_grid],
            color=color,
            alpha=1,
            zorder=z_level,
            linewidth=lw_m,
            label=tmp_label,
        )
        axis.plot(
            t_s,
            lower[:n_grid],
            color=color,
            alpha=1,
            zorder=z_level,
            linewidth=lw_m,
        )

    if keep_slope:
        axis.plot(
            t_s[:maxid] + meta["t"][0],
            trajectory[:maxid],
            color=color,
            alpha=1,
            zorder=z_level,
            linewidth=lw_m,
        )

    if keep_ndt:
        axis.axvline(
            x=meta["t"][0],
            ymin=ylim_low,
            ymax=ylim_high,
            color=color,
            linestyle="--",
            linewidth=lw_m,
            zorder=z_level,
            alpha=1,
        )

    if keep_starting_point:
        x_start = meta["t"][0] + markershift_starting_point
        y_start = lower[0] + (meta["z"][0] * (upper[0] - lower[0]))
        axis.scatter(
            x_start,
            y_start,
            s=markersize_starting_point,
            marker=markertype_starting_point,
            color=color,
            alpha=1,
            zorder=z_level,
        )
|
|
def plot_func_model_n(
    model_name,
    theta,
    axis,
    n_trajectories=10,
    value_range=None,
    bin_size=0.05,
    n_samples=10,
    linewidth_histogram=0.5,
    linewidth_model=0.5,
    legend_fontsize=7,
    legend_shadow=True,
    legend_location="upper right",
    delta_t_model=0.001,
    add_legend=True,
    alpha=1,
    keep_frame=False,
    random_state=None,
    expected_random_params=True,
    **kwargs,
):
    """Plot per-choice RT histograms, a model cartoon and trajectories for an n-choice model.

    Simulates ``n_samples`` trials of ``model_name`` at ``theta``. For each
    possible choice it draws a step histogram of that choice's valid RTs
    (non -999) sitting on the boundary's initial height. Each trial adds
    ``1 / (bin_size * n_samples)``, so the choice histograms are defective
    densities over *all* simulated trials, timeouts included. It then draws the
    noise-free cartoon and, optionally, ``n_trajectories`` noisy single-trial
    trajectories.

    Arguments:
        model_name: str
            Model key understood by ssms (``Simulator`` / ``get_model_config``).
        theta: list of lists
            Parameters forwarded to ``Simulator.simulate``. The cartoon helpers
            read ``theta[0]`` positionally in the model's ``params`` order.
        axis: matplotlib.axes.Axes
            Axis to plot into.
        value_range: sequence of float
            RT range; only the first and last entries are used. Required.

    Optional:
        n_trajectories: int <default=10>
            Number of noisy single-trial trajectories to overlay.
        bin_size: float <default=0.05>
            Histogram bin width.
        n_samples: int <default=10>
            Number of simulated trials for the histograms.
        linewidth_histogram: float <default=0.5>
            Line width of histogram elements.
        linewidth_model: float <default=0.5>
            Line width of cartoon elements.
        legend_fontsize: float <default=7>
            Legend font size.
        legend_shadow: bool <default=True>
            Draw a shadow behind the legend.
        legend_location: str <default="upper right">
            matplotlib legend location string.
        delta_t_model: float <default=0.001>
            Step of the time grid used to draw the cartoon and trajectories.
        add_legend: bool <default=True>
            Add a legend with one entry per possible choice.
        alpha: float <default=1>
            Transparency of histograms and cartoon elements.
        keep_frame: bool <default=False>
            Keep the axis frame; otherwise it is switched off.
        random_state: int or None
            If given, seeds NumPy's global RNG (``np.random.seed``).
        expected_random_params: bool <default=True>
            Collapse trial-by-trial variability parameters for the cartoon and
            apply the registered expected-NDT shift, if any.
        **kwargs:
            ``ylim`` (default 4) sets the upper y-limit and is consumed here.
            Everything else is forwarded to ``_add_trajectories_n``.

    Returns:
        The ``axis`` that was drawn into.

    Raises:
        NotImplementedError: if ``value_range`` is None.
    """
    color_dict = {
        -1: "black",
        0: "black",
        1: "green",
        2: "blue",
        3: "red",
        4: "orange",
        5: "purple",
        6: "brown",
    }

    if value_range is None:
        raise NotImplementedError("value_range keyword argument must be supplied.")
    # Only the end points of value_range matter.
    if len(value_range) > 2:
        value_range = (value_range[0], value_range[-1])
    bins = np.arange(value_range[0], value_range[-1], bin_size)

    ylim = kwargs.pop("ylim", 4)
    axis.set_xlim(value_range[0], value_range[-1])
    axis.set_ylim(0, ylim)

    # Simulator resolution. delta_t_model only controls the plotting grid.
    sim_delta_t = 0.001

    if random_state is not None:
        np.random.seed(random_state)
    sim = Simulator(model=model_name)

    # Each simulate() call gets its own seed drawn from the global RNG.
    rand_int = np.random.choice(400000000)
    sim_out = sim.simulate(
        theta=theta,
        n_samples=n_samples,
        no_noise=False,
        delta_t=sim_delta_t,
        random_state=rand_int,
    )
    choices = sim_out["metadata"]["possible_choices"]

    sim_out_traj = {}
    for i in range(n_trajectories):
        rand_int = np.random.choice(400000000)
        sim_out_traj[i] = sim.simulate(
            theta=theta,
            n_samples=1,
            no_noise=False,
            delta_t=sim_delta_t,
            random_state=rand_int,
            smooth_unif=False,
        )
        _patch_trajectory_t_with_actual_ndt(sim_out_traj[i], delta_t=sim_delta_t)

    theta_cartoon = (
        _collapse_theta_for_cartoon(model_name, theta) if expected_random_params else theta
    )
    sim_out_no_noise = sim.simulate(
        theta=theta_cartoon,
        n_samples=1,
        no_noise=True,
        delta_t=sim_delta_t,
        smooth_unif=False,
    )
    if expected_random_params:
        # The shift uses the caller's (uncollapsed) parameters.
        params = _model_config[model_name]["params"]
        theta_dict = dict(zip(params, theta[0]))
        _apply_expected_t_shift(model_name, theta_dict, sim_out_no_noise)

    rts = sim_out["rts"]
    b = np.maximum(sim_out["metadata"]["boundary"], 0)
    bottom = b[0]
    # Normalized by all simulated trials (timeouts included).
    density_per_sample = (1 / bin_size) / rts.shape[0]
    for j, choice in enumerate(choices):
        tmp_label = "PostPred" if (add_legend and j == 0) else None
        choice_rts = rts[(sim_out["choices"] == choice) & (rts != -999)]
        axis.hist(
            np.abs(choice_rts),
            bins=bins,
            bottom=bottom,
            weights=np.tile(density_per_sample, reps=choice_rts.shape[0]),
            histtype="step",
            alpha=alpha,
            color=color_dict[choice],
            zorder=-1,
            label=tmp_label,
            linewidth=linewidth_histogram,
        )

    t_s = np.arange(0, sim_out["metadata"]["max_t"], delta_t_model)

    _add_model_n_cartoon_to_ax(
        sample=sim_out_no_noise,
        axis=axis,
        delta_t_graph=delta_t_model,
        sample_hist_alpha=alpha,
        lw_m=linewidth_model,
        tmp_label="PostPred" if add_legend else None,
        linestyle="-",
        ylim=ylim,
        t_s=t_s,
        color_dict=color_dict,
        zorder_cnt=0,
    )

    if n_trajectories > 0:
        _add_trajectories_n(
            axis=axis,
            sample=sim_out_traj,
            t_s=t_s,
            delta_t_graph=delta_t_model,
            n_trajectories=n_trajectories,
            **kwargs,
        )

    if add_legend:
        # Explicit handles/labels: one colored entry per possible choice.
        custom_elems = [
            Line2D([0], [0], color=color_dict[choice], lw=1) for choice in choices
        ]
        custom_titles = ["response: " + str(choice) for choice in choices]
        axis.legend(
            custom_elems,
            custom_titles,
            fontsize=legend_fontsize,
            shadow=legend_shadow,
            loc=legend_location,
        )

    if not keep_frame:
        axis.set_frame_on(False)

    return axis
|
|
def _add_trajectories_n(
    axis=None,
    sample=None,
    t_s=None,
    delta_t_graph=0.01,
    n_trajectories=10,
    highlight_trajectory_rt_choice=True,
    markersize_trajectory_rt_choice=50,
    markertype_trajectory_rt_choice="*",
    markercolor_trajectory_rt_choice="black",
    linewidth_trajectories=1,
    alpha_trajectories=0.5,
    color_trajectories="black",
    **kwargs,
):
    """Add noisy single-trial accumulator trajectories of an n-choice model to ``axis``.

    Every accumulator column ``j`` of each trajectory is drawn in
    ``color_dict[j]`` and offset by that trial's ``metadata['t'][0]``. With
    ``highlight_trajectory_rt_choice`` the chosen accumulator gets a marker on
    the (non-negative) boundary of ``sample[0]``. Every other accumulator gets a
    triangle marker at its final value, nudged 0.05 to the right.

    Arguments:
        axis: matplotlib.axes.Axes
            Axis to draw into.
        sample: dict
            Maps ``0..n_trajectories-1`` to ``Simulator.simulate`` outputs with
            ``n_samples=1``. A 1-D trajectory is treated as a single accumulator.
        t_s: numpy.ndarray
            Time grid the trajectory steps are drawn against.
        n_trajectories: int
            How many entries of ``sample`` to draw.
        markersize_trajectory_rt_choice, markertype_trajectory_rt_choice:
            Size and marker of the chosen-accumulator highlight.
        linewidth_trajectories, alpha_trajectories: line style.
        delta_t_graph, markercolor_trajectory_rt_choice, color_trajectories:
            Accepted for interface compatibility but currently unused. Colors
            come from the fixed per-accumulator ``color_dict``.
        **kwargs: ignored, so callers can forward extra plotting options.
    """
    color_dict = {
        -1: "black",
        0: "black",
        1: "green",
        2: "blue",
        3: "red",
        4: "orange",
        5: "purple",
        6: "brown",
    }

    # Non-negative boundary of the first sample, used for every trajectory.
    # Indexing it at the decision step equals the former roll-then-offset
    # lookup, minus its out-of-range failure for late decisions.
    boundary = np.maximum(sample[0]["metadata"]["boundary"], 0)
    last_boundary_idx = boundary.shape[0] - 1
    last_grid_idx = t_s.shape[0] - 1

    for i in range(n_trajectories):
        traj = np.asarray(sample[i]["metadata"]["trajectory"])
        if traj.ndim == 1:
            traj = traj[:, np.newaxis]
        chosen = sample[i]["choices"].flatten().item()
        t_offset = sample[i]["metadata"]["t"][0]

        for j in range(traj.shape[1]):
            column = traj[:, j]
            # Number of non-padding (> -999) samples minus one, capped to the grid.
            maxid = min(max(np.count_nonzero(column > -999) - 1, 0), t_s.shape[0])
            tip = min(maxid, last_grid_idx)

            axis.plot(
                t_s[:maxid] + t_offset,
                column[:maxid],
                color=color_dict[j],
                alpha=alpha_trajectories,
                linewidth=linewidth_trajectories,
                zorder=2000 + i,
            )

            if not highlight_trajectory_rt_choice:
                continue
            if chosen == j:
                axis.scatter(
                    t_s[tip] + t_offset,
                    boundary[min(maxid, last_boundary_idx)],
                    markersize_trajectory_rt_choice,
                    color=color_dict[chosen],
                    alpha=1,
                    marker=markertype_trajectory_rt_choice,
                    zorder=2000 + i,
                )
            else:
                axis.scatter(
                    t_s[tip] + t_offset + 0.05,
                    column[maxid],
                    markersize_trajectory_rt_choice,
                    color=color_dict[j],
                    alpha=1,
                    marker=5,
                    zorder=2000 + i,
                )
|
|
def _add_model_n_cartoon_to_ax(
    sample=None,
    axis=None,
    delta_t_graph=None,
    sample_hist_alpha=None,
    keep_boundary=True,
    keep_ndt=True,
    keep_slope=True,
    keep_starting_point=True,
    lw_m=None,
    linestyle="-",
    tmp_label=None,
    ylim=None,
    t_s=None,
    zorder_cnt=1,
    color_dict=None,
):
    """Draw the noise-free cartoon of an n-choice model onto ``axis``.

    The non-negative boundary ``max(boundary, 0)`` is delayed by
    ``int(t / delta_t_graph + 1)`` samples, holding its initial value over the
    gap, and drawn in black against ``t_s``. A vertical line marks ``t``. Each
    accumulator column ``i`` of the noise-free trajectory is drawn in
    ``color_dict[i]``, offset by ``t``. A 1-D trajectory is treated as a single
    accumulator.

    Notes:
        * The vertical line at ``t`` is toggled by ``keep_starting_point``;
          ``keep_ndt`` is accepted but currently unused.
        * ``-ylim``/``ylim`` are passed to ``Axes.axvline`` as ``ymin``/``ymax``,
          which matplotlib interprets as axes fractions, not data values.

    Returns:
        The first entry of the delayed boundary, i.e. its held initial value.
    """
    b = np.maximum(sample["metadata"]["boundary"], 0)
    b_init = b[0]
    n_roll = int(np.asarray(sample["metadata"]["t"]).flat[0] / delta_t_graph + 1)
    b = np.roll(b, n_roll)
    b[:n_roll] = b_init

    if keep_boundary:
        axis.plot(
            t_s,
            b[: t_s.shape[0]],
            color="black",
            alpha=sample_hist_alpha,
            zorder=1000 + zorder_cnt,
            linewidth=lw_m,
            linestyle=linestyle,
            label=tmp_label,
        )

    if keep_starting_point:
        axis.axvline(
            x=sample["metadata"]["t"][0],
            ymin=-ylim,
            ymax=ylim,
            color="black",
            linestyle=linestyle,
            linewidth=lw_m,
            alpha=sample_hist_alpha,
        )

    if keep_slope:
        tmp_traj = np.asarray(sample["metadata"]["trajectory"])
        if tmp_traj.ndim == 1:
            # Single accumulator: give it an explicit column axis.
            tmp_traj = tmp_traj[:, np.newaxis]

        for i in range(tmp_traj.shape[1]):
            column = tmp_traj[:, i]
            tmp_maxid = np.minimum(np.argmax(np.where(column > -999)), t_s.shape[0])
            axis.plot(
                t_s[:tmp_maxid] + sample["metadata"]["t"][0],
                column[:tmp_maxid],
                color=color_dict[i],
                linestyle=linestyle,
                alpha=sample_hist_alpha,
                zorder=1000 + zorder_cnt,
                linewidth=lw_m,
            )

    return b[0]