Coverage for tasks/cardinal_expdetthreshold.py: 33%
234 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-08 23:14 +0000
1#!/usr/bin/env python
3"""
4camcops_server/tasks/cardinal_expdetthreshold.py
6===============================================================================
8 Copyright (C) 2012, University of Cambridge, Department of Psychiatry.
9 Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
11 This file is part of CamCOPS.
13 CamCOPS is free software: you can redistribute it and/or modify
14 it under the terms of the GNU General Public License as published by
15 the Free Software Foundation, either version 3 of the License, or
16 (at your option) any later version.
18 CamCOPS is distributed in the hope that it will be useful,
19 but WITHOUT ANY WARRANTY; without even the implied warranty of
20 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 GNU General Public License for more details.
23 You should have received a copy of the GNU General Public License
24 along with CamCOPS. If not, see <https://www.gnu.org/licenses/>.
26===============================================================================
28"""
30import math
31import logging
32from typing import List, Optional, Tuple, Type
34from cardinal_pythonlib.maths_numpy import inv_logistic, logistic
35import cardinal_pythonlib.rnc_web as ws
36from matplotlib.figure import Figure
37import numpy as np
38from sqlalchemy.sql.schema import Column
39from sqlalchemy.sql.sqltypes import Float, Integer, Text, UnicodeText
41from camcops_server.cc_modules.cc_constants import (
42 CssClass,
43 MatplotlibConstants,
44 PlotDefaults,
45)
46from camcops_server.cc_modules.cc_db import (
47 ancillary_relationship,
48 GenericTabletRecordMixin,
49 TaskDescendant,
50)
51from camcops_server.cc_modules.cc_html import get_yes_no_none, tr_qa
52from camcops_server.cc_modules.cc_request import CamcopsRequest
53from camcops_server.cc_modules.cc_sqla_coltypes import (
54 CamcopsColumn,
55 PendulumDateTimeAsIsoTextColType,
56)
57from camcops_server.cc_modules.cc_sqlalchemy import Base
58from camcops_server.cc_modules.cc_task import Task, TaskHasPatientMixin
59from camcops_server.cc_modules.cc_text import SS
log = logging.getLogger(__name__)

# Detection probabilities (besides 0.5) at which the fitted logistic function
# is marked on the fit plot and reported in the task's HTML summary.
LOWER_MARKER, UPPER_MARKER = 0.25, 0.75

# Column comment shared by the logistic-fit result columns of the task table.
EQUATION_COMMENT = (
    "logits: L(X) = intercept + slope * X; probability: "
    "P = 1 / (1 + exp(-intercept - slope * X))"
)

# Values stored in the task's "modality" column.
MODALITY_AUDITORY, MODALITY_VISUAL = 0, 1

# Decimal places used when displaying floating-point values.
DP = 3
# =============================================================================
# CardinalExpDetThreshold
# =============================================================================


class CardinalExpDetThresholdTrial(
    GenericTabletRecordMixin, TaskDescendant, Base
):
    """
    A single trial of the Cardinal_ExpDetThreshold task.

    Rows are linked to their parent :class:`CardinalExpDetThreshold` task via
    ``cardinal_expdetthreshold_id``.
    """

    __tablename__ = "cardinal_expdetthreshold_trials"

    cardinal_expdetthreshold_id = Column(
        "cardinal_expdetthreshold_id",
        Integer,
        nullable=False,
        comment="FK to CardinalExpDetThreshold",
    )
    trial = Column(
        "trial", Integer, nullable=False, comment="Trial number (0-based)"
    )

    # Results
    trial_ignoring_catch_trials = Column(
        "trial_ignoring_catch_trials",
        Integer,
        comment="Trial number, ignoring catch trials (0-based)",
    )
    target_presented = Column(
        "target_presented", Integer, comment="Target presented? (0 no, 1 yes)"
    )
    target_time = Column(
        "target_time",
        PendulumDateTimeAsIsoTextColType,
        comment="Target presentation time (ISO-8601)",
    )
    intensity = Column(
        "intensity", Float, comment="Target intensity (0.0-1.0)"
    )
    choice_time = Column(
        "choice_time",
        PendulumDateTimeAsIsoTextColType,
        comment="Time choice offered (ISO-8601)",
    )
    responded = Column(
        "responded", Integer, comment="Responded? (0 no, 1 yes)"
    )
    response_time = Column(
        "response_time",
        PendulumDateTimeAsIsoTextColType,
        comment="Time of response (ISO-8601)",
    )
    response_latency_ms = Column(
        "response_latency_ms", Integer, comment="Response latency (ms)"
    )
    yes = Column(
        "yes", Integer, comment="Subject chose YES? (0 didn't, 1 did)"
    )
    no = Column("no", Integer, comment="Subject chose NO? (0 didn't, 1 did)")
    caught_out_reset = Column(
        "caught_out_reset",
        Integer,
        comment="Caught out on catch trial, thus reset? (0 no, 1 yes)",
    )
    trial_num_in_calculation_sequence = Column(
        "trial_num_in_calculation_sequence",
        Integer,
        comment="Trial number as used for threshold calculation",
    )

    @classmethod
    def get_html_table_header(cls) -> str:
        """
        Return the opening ``<table>`` tag plus the header row for a table of
        trials. Rows come from :meth:`get_html_table_row`; the caller is
        responsible for closing the table.
        """
        return f"""
            <table class="{CssClass.EXTRADETAIL}">
                <tr>
                    <th>Trial# (0-based)</th>
                    <th>Trial# (ignoring catch trials) (0-based)</th>
                    <th>Target presented?</th>
                    <th>Target time</th>
                    <th>Intensity</th>
                    <th>Choice time</th>
                    <th>Responded?</th>
                    <th>Response time</th>
                    <th>Response latency (ms)</th>
                    <th>Yes?</th>
                    <th>No?</th>
                    <th>Caught out (and reset)?</th>
                    <th>Trial# in calculation sequence</th>
                </tr>
        """

    def get_html_table_row(self) -> str:
        """
        Return one ``<tr>`` row of HTML describing this trial, with 13 cells
        in the same order as the columns of :meth:`get_html_table_header`.
        """
        # The row must be closed with </tr> to match its opening <tr>;
        # closing it with </th> produces malformed HTML.
        return ("<tr>" + "<td>{}</td>" * 13 + "</tr>").format(
            self.trial,
            self.trial_ignoring_catch_trials,
            self.target_presented,
            self.target_time,
            ws.number_to_dp(self.intensity, DP),
            self.choice_time,
            self.responded,
            self.response_time,
            self.response_latency_ms,
            self.yes,
            self.no,
            ws.webify(self.caught_out_reset),
            ws.webify(self.trial_num_in_calculation_sequence),
        )

    # -------------------------------------------------------------------------
    # TaskDescendant overrides
    # -------------------------------------------------------------------------

    @classmethod
    def task_ancestor_class(cls) -> Optional[Type["Task"]]:
        """Return the task class that owns these trial records."""
        return CardinalExpDetThreshold

    def task_ancestor(self) -> Optional["CardinalExpDetThreshold"]:
        """Return the parent task instance linked via the FK column."""
        return CardinalExpDetThreshold.get_linked(
            self.cardinal_expdetthreshold_id, self
        )
class CardinalExpDetThreshold(TaskHasPatientMixin, Task):
    """
    Server implementation of the Cardinal_ExpDetThreshold task.

    Stores the task configuration, the fitted logistic function (see
    ``EQUATION_COMMENT``) and, via :attr:`trials`, the individual
    :class:`CardinalExpDetThresholdTrial` records.
    """

    __tablename__ = "cardinal_expdetthreshold"
    shortname = "Cardinal_ExpDetThreshold"
    use_landscape_for_pdf = True

    # Config
    modality = Column(
        "modality", Integer, comment="Modality (0 auditory, 1 visual)"
    )
    target_number = Column(
        "target_number",
        Integer,
        comment="Target number (within available targets of that modality)",
    )
    background_filename = CamcopsColumn(
        "background_filename",
        Text,
        exempt_from_anonymisation=True,
        comment="Filename of media used for background",
    )
    target_filename = CamcopsColumn(
        "target_filename",
        Text,
        exempt_from_anonymisation=True,
        comment="Filename of media used for target",
    )
    visual_target_duration_s = Column(
        "visual_target_duration_s", Float, comment="Visual target duration (s)"
    )
    background_intensity = Column(
        "background_intensity",
        Float,
        comment="Intensity of background (0.0-1.0)",
    )
    start_intensity_min = Column(
        "start_intensity_min",
        Float,
        comment="Minimum starting intensity (0.0-1.0)",
    )
    start_intensity_max = Column(
        "start_intensity_max",
        Float,
        comment="Maximum starting intensity (0.0-1.0)",
    )
    initial_large_intensity_step = Column(
        "initial_large_intensity_step",
        Float,
        comment="Initial, large, intensity step (0.0-1.0)",
    )
    main_small_intensity_step = Column(
        "main_small_intensity_step",
        Float,
        comment="Main, small, intensity step (0.0-1.0)",
    )
    num_trials_in_main_sequence = Column(
        "num_trials_in_main_sequence",
        Integer,
        comment="Number of trials required in main sequence",
    )
    p_catch_trial = Column(
        "p_catch_trial", Float, comment="Probability of catch trial"
    )
    prompt = CamcopsColumn(
        "prompt",
        UnicodeText,
        exempt_from_anonymisation=True,
        comment="Prompt given to subject",
    )
    iti_s = Column("iti_s", Float, comment="Intertrial interval (s)")

    # Results
    finished = Column(
        "finished",
        Integer,
        comment="Subject finished successfully (0 no, 1 yes)",
    )
    intercept = Column("intercept", Float, comment=EQUATION_COMMENT)
    slope = Column("slope", Float, comment=EQUATION_COMMENT)
    k = Column("k", Float, comment=EQUATION_COMMENT + "; k = slope")
    theta = Column(
        "theta",
        Float,
        comment=EQUATION_COMMENT + "; theta = -intercept/k = -intercept/slope",
    )

    # Relationships
    trials = ancillary_relationship(
        parent_class_name="CardinalExpDetThreshold",
        ancillary_class_name="CardinalExpDetThresholdTrial",
        ancillary_fk_to_parent_attr_name="cardinal_expdetthreshold_id",
        ancillary_order_by_attr_name="trial",
    )  # type: List[CardinalExpDetThresholdTrial]

    @staticmethod
    def longname(req: "CamcopsRequest") -> str:
        """Return the (translated) long name of the task."""
        _ = req.gettext
        return _(
            "Cardinal RN – Threshold determination for "
            "Expectation–Detection task"
        )

    def is_complete(self) -> bool:
        """The task is complete when the subject finished successfully."""
        return bool(self.finished)

    def _get_figures(
        self, req: CamcopsRequest
    ) -> Tuple[Figure, Optional[Figure]]:
        """
        Create and return figures. Returns ``trialfig, fitfig``.

        - ``trialfig`` plots intensity against trial number for every trial.
          Markers distinguish unscored and scored hits and misses, and
          catch-trial correct rejections ("CR") and false alarms ("FA").
          It is always created, but is empty when there are no trials.
        - ``fitfig`` plots the fitted logistic function (from ``k`` and
          ``theta``) together with the scored trials, which are those with a
          ``trial_num_in_calculation_sequence``. Hits are drawn at y = 1 and
          misses at y = 0, jittered so coincident points stay visible.
          It is ``None`` if there are no trials or if ``k`` or ``theta`` is
          missing.
        """
        trialarray = self.trials

        # Constants
        jitter_step = 0.02
        dp_to_consider_same_for_jitter = 3
        y_extra_space = 0.1
        x_extra_space = 0.02
        figsize = (
            PlotDefaults.FULLWIDTH_PLOT_WIDTH / 2,
            PlotDefaults.FULLWIDTH_PLOT_WIDTH / 2,
        )

        # Figure and axes
        trialfig = req.create_figure(figsize=figsize)
        trialax = trialfig.add_subplot(MatplotlibConstants.WHOLE_PANEL)
        fitfig = None  # type: Optional[Figure]

        # Anything to do?
        if not trialarray:
            return trialfig, fitfig

        # Data
        notcalc_detected_x = []
        notcalc_detected_y = []
        notcalc_missed_x = []
        notcalc_missed_y = []
        calc_detected_x = []
        calc_detected_y = []
        calc_missed_x = []
        calc_missed_y = []
        catch_detected_x = []
        catch_detected_y = []
        catch_missed_x = []
        catch_missed_y = []
        all_x = []
        all_y = []
        for t in trialarray:
            x = t.trial
            y = t.intensity
            all_x.append(x)
            all_y.append(y)
            if t.trial_num_in_calculation_sequence is not None:
                # Scored trial (part of the threshold calculation)
                if t.yes:
                    calc_detected_x.append(x)
                    calc_detected_y.append(y)
                else:
                    calc_missed_x.append(x)
                    calc_missed_y.append(y)
            elif t.target_presented:
                # Unscored target trial
                if t.yes:
                    notcalc_detected_x.append(x)
                    notcalc_detected_y.append(y)
                else:
                    notcalc_missed_x.append(x)
                    notcalc_missed_y.append(y)
            else:  # catch trial
                if t.yes:
                    catch_detected_x.append(x)
                    catch_detected_y.append(y)
                else:
                    catch_missed_x.append(x)
                    catch_missed_y.append(y)

        # Create trialfig plots
        trialax.plot(
            all_x,
            all_y,
            marker=MatplotlibConstants.MARKER_NONE,
            color=MatplotlibConstants.COLOUR_GREY_50,
            linestyle=MatplotlibConstants.LINESTYLE_SOLID,
            label=None,
        )
        trialax.plot(
            notcalc_missed_x,
            notcalc_missed_y,
            marker=MatplotlibConstants.MARKER_CIRCLE,
            color=MatplotlibConstants.COLOUR_BLACK,
            linestyle=MatplotlibConstants.LINESTYLE_NONE,
            label="miss",
        )
        trialax.plot(
            notcalc_detected_x,
            notcalc_detected_y,
            marker=MatplotlibConstants.MARKER_PLUS,
            color=MatplotlibConstants.COLOUR_BLACK,
            linestyle=MatplotlibConstants.LINESTYLE_NONE,
            label="hit",
        )
        trialax.plot(
            calc_missed_x,
            calc_missed_y,
            marker=MatplotlibConstants.MARKER_CIRCLE,
            color=MatplotlibConstants.COLOUR_RED,
            linestyle=MatplotlibConstants.LINESTYLE_NONE,
            label="miss, scored",
        )
        trialax.plot(
            calc_detected_x,
            calc_detected_y,
            marker=MatplotlibConstants.MARKER_PLUS,
            color=MatplotlibConstants.COLOUR_BLUE,
            linestyle=MatplotlibConstants.LINESTYLE_NONE,
            label="hit, scored",
        )
        trialax.plot(
            catch_missed_x,
            catch_missed_y,
            marker=MatplotlibConstants.MARKER_CIRCLE,
            color=MatplotlibConstants.COLOUR_GREEN,
            linestyle=MatplotlibConstants.LINESTYLE_NONE,
            label="CR",
        )
        trialax.plot(
            catch_detected_x,
            catch_detected_y,
            marker=MatplotlibConstants.MARKER_STAR,
            color=MatplotlibConstants.COLOUR_GREEN,
            linestyle=MatplotlibConstants.LINESTYLE_NONE,
            label="FA",
        )
        leg = trialax.legend(
            numpoints=1,
            fancybox=True,  # for set_alpha (below)
            loc="best",  # bbox_to_anchor=(0.75, 1.05)
            labelspacing=0,
            handletextpad=0,
            prop=req.fontprops,
        )
        leg.get_frame().set_alpha(0.5)
        trialax.set_xlabel("Trial number (0-based)", fontdict=req.fontdict)
        trialax.set_ylabel("Intensity", fontdict=req.fontdict)
        trialax.set_ylim(0 - y_extra_space, 1 + y_extra_space)
        trialax.set_xlim(-0.5, len(trialarray) - 0.5)
        req.set_figure_font_sizes(trialax)

        # Anything to do for fitfig?
        if self.k is None or self.theta is None:
            return trialfig, fitfig

        # Create fitfig
        fitfig = req.create_figure(figsize=figsize)
        fitax = fitfig.add_subplot(MatplotlibConstants.WHOLE_PANEL)
        detected_x = []
        detected_y = []
        missed_x = []
        missed_y = []
        all_x = []
        # Number of hits/misses already plotted at each (rounded) intensity;
        # each further point at the same intensity is shifted by jitter_step
        # (hits downwards from 1, misses upwards from 0).
        hit_counts = {}  # type: dict
        miss_counts = {}  # type: dict
        for t in trialarray:
            if t.trial_num_in_calculation_sequence is None:
                continue  # only scored trials contribute to the fit plot
            if t.intensity is None:
                # Such a trial cannot be placed on the intensity axis, and
                # formatting None with a float format spec raises TypeError.
                continue
            all_x.append(t.intensity)
            approx_x = f"{t.intensity:.{dp_to_consider_same_for_jitter}f}"
            if t.yes:
                n_prev = hit_counts.get(approx_x, 0)
                detected_y.append(1 - n_prev * jitter_step)
                detected_x.append(t.intensity)
                hit_counts[approx_x] = n_prev + 1
            else:
                n_prev = miss_counts.get(approx_x, 0)
                missed_y.append(0 + n_prev * jitter_step)
                missed_x.append(t.intensity)
                miss_counts[approx_x] = n_prev + 1

        # Again, anything to do for fitfig?
        if not all_x:
            return trialfig, fitfig

        fit_x = np.arange(0.0 - x_extra_space, 1.0 + x_extra_space, 0.001)
        fit_y = logistic(fit_x, self.k, self.theta)
        fitax.plot(
            fit_x,
            fit_y,
            color=MatplotlibConstants.COLOUR_GREEN,
            linestyle=MatplotlibConstants.LINESTYLE_SOLID,
        )
        fitax.plot(
            missed_x,
            missed_y,
            marker=MatplotlibConstants.MARKER_CIRCLE,
            color=MatplotlibConstants.COLOUR_RED,
            linestyle=MatplotlibConstants.LINESTYLE_NONE,
        )
        fitax.plot(
            detected_x,
            detected_y,
            marker=MatplotlibConstants.MARKER_PLUS,
            color=MatplotlibConstants.COLOUR_BLUE,
            linestyle=MatplotlibConstants.LINESTYLE_NONE,
        )
        fitax.set_ylim(0 - y_extra_space, 1 + y_extra_space)
        fitax.set_xlim(
            np.amin(all_x) - x_extra_space, np.amax(all_x) + x_extra_space
        )
        # Dotted guide lines from the axes to the fitted curve at the lower,
        # 50% and upper detection probabilities.
        marker_points = []
        for y in (LOWER_MARKER, 0.5, UPPER_MARKER):
            x = inv_logistic(y, self.k, self.theta)
            marker_points.append((x, y))
        for p in marker_points:
            fitax.plot(
                [p[0], p[0]],  # x
                [-1, p[1]],  # y
                color=MatplotlibConstants.COLOUR_GREY_50,
                linestyle=MatplotlibConstants.LINESTYLE_DOTTED,
            )
            fitax.plot(
                [-1, p[0]],  # x
                [p[1], p[1]],  # y
                color=MatplotlibConstants.COLOUR_GREY_50,
                linestyle=MatplotlibConstants.LINESTYLE_DOTTED,
            )
        fitax.set_xlabel("Intensity", fontdict=req.fontdict)
        fitax.set_ylabel(
            "Detected? (0=no, 1=yes; jittered)", fontdict=req.fontdict
        )
        req.set_figure_font_sizes(fitax)

        # Done
        return trialfig, fitfig

    def get_trial_html(self, req: CamcopsRequest) -> str:
        """
        Return HTML for the per-trial table and, if the task is complete, the
        trial and fit figures side by side.

        Note re plotting markers without lines:

        .. code-block:: python

            import matplotlib.pyplot as plt

            fig, ax = plt.subplots()
            ax.plot([1, 2], [1, 2], marker="+", color="r", linestyle="-")
            ax.plot([1, 2], [2, 1], marker="o", color="b", linestyle="None")
            fig.savefig("test.png")
            # ... the "absent" line does NOT "cut" the red one.

        Args:
            req: the :class:`CamcopsRequest`, used to create the figures and
                render them to HTML.

        Returns:
            an HTML fragment (string)
        """
        parts = [CardinalExpDetThresholdTrial.get_html_table_header()]
        parts.extend(t.get_html_table_row() for t in self.trials)
        parts.append("</table>")
        html = "".join(parts)

        # Don't add figures if we're incomplete
        if not self.is_complete():
            return html

        # Add figures
        trialfig, fitfig = self._get_figures(req)
        # fitfig is None when there are no trials or no fitted k/theta;
        # render an empty cell rather than passing None to the renderer.
        fit_html = (
            req.get_html_from_pyplot_figure(fitfig)
            if fitfig is not None
            else ""
        )

        html += f"""
            <table class="{CssClass.NOBORDER}">
                <tr>
                    <td class="{CssClass.NOBORDERPHOTO}">
                        {req.get_html_from_pyplot_figure(trialfig)}
                    </td>
                    <td class="{CssClass.NOBORDERPHOTO}">
                        {fit_html}
                    </td>
                </tr>
            </table>
        """

        return html

    def logistic_x_from_p(self, p: Optional[float]) -> Optional[float]:
        """
        Return the intensity X at which the fitted logistic function gives
        detection probability ``p``. X solves
        ``p = 1 / (1 + exp(-intercept - slope * X))``, so
        ``X = (log(p / (1 - p)) - intercept) / slope``.

        Returns ``None`` in any of these cases:

        - ``p``, ``intercept`` or ``slope`` is missing;
        - ``p`` is outside the open interval (0, 1);
        - ``slope`` is zero.
        """
        try:
            return (math.log(p / (1 - p)) - self.intercept) / self.slope
        except (TypeError, ValueError, ZeroDivisionError):
            # TypeError: a value is None.
            # ValueError: log of a non-positive number (p <= 0 or p > 1).
            # ZeroDivisionError: p == 1 or slope == 0.
            return None

    def get_task_html(self, req: CamcopsRequest) -> str:
        """
        Return the full HTML for this task.

        The HTML contains a completeness summary, an explanation, the
        configuration table, the fitted-function results (including the
        intensities for the marker detection probabilities) and the
        per-trial HTML.
        """
        if self.modality == MODALITY_AUDITORY:
            modality = req.sstring(SS.AUDITORY)
        elif self.modality == MODALITY_VISUAL:
            modality = req.sstring(SS.VISUAL)
        else:
            modality = None
        h = f"""
            <div class="{CssClass.SUMMARY}">
                <table class="{CssClass.SUMMARY}">
                    {self.get_is_complete_tr(req)}
                </table>
            </div>
            <div class="{CssClass.EXPLANATION}">
                The ExpDet-Threshold task measures visual and auditory
                thresholds for stimuli on a noisy background, using a
                single-interval up/down method. It is intended as a prequel to
                the Expectation–Detection task.
            </div>
            <table class="{CssClass.TASKCONFIG}">
                <tr>
                    <th width="50%">Configuration variable</th>
                    <th width="50%">Value</th>
                </tr>
        """
        h += tr_qa("Modality", modality)
        h += tr_qa("Target number", self.target_number)
        h += tr_qa("Background filename", ws.webify(self.background_filename))
        h += tr_qa("Background intensity", self.background_intensity)
        h += tr_qa("Target filename", ws.webify(self.target_filename))
        h += tr_qa(
            "(For visual targets) Target duration (s)",
            self.visual_target_duration_s,
        )
        h += tr_qa("Start intensity (minimum)", self.start_intensity_min)
        h += tr_qa("Start intensity (maximum)", self.start_intensity_max)
        h += tr_qa(
            "Initial (large) intensity step", self.initial_large_intensity_step
        )
        h += tr_qa(
            "Main (small) intensity step", self.main_small_intensity_step
        )
        h += tr_qa(
            "Number of trials in main sequence",
            self.num_trials_in_main_sequence,
        )
        h += tr_qa("Probability of a catch trial", self.p_catch_trial)
        h += tr_qa("Prompt", self.prompt)
        h += tr_qa("Intertrial interval (ITI) (s)", self.iti_s)
        h += f"""
            </table>
            <table class="{CssClass.TASKDETAIL}">
                <tr><th width="50%">Measure</th><th width="50%">Value</th></tr>
        """
        h += tr_qa("Finished?", get_yes_no_none(req, self.finished))
        h += tr_qa("Logistic intercept", ws.number_to_dp(self.intercept, DP))
        h += tr_qa("Logistic slope", ws.number_to_dp(self.slope, DP))
        h += tr_qa("Logistic k (= slope)", ws.number_to_dp(self.k, DP))
        h += tr_qa(
            "Logistic theta (= –intercept/slope)",
            ws.number_to_dp(self.theta, DP),
        )
        h += tr_qa(
            f"Intensity for {100 * LOWER_MARKER}% detection",
            ws.number_to_dp(self.logistic_x_from_p(LOWER_MARKER), DP),
        )
        # At p = 0.5 the logit is zero, so the intensity equals theta.
        h += tr_qa(
            "Intensity for 50% detection", ws.number_to_dp(self.theta, DP)
        )
        h += tr_qa(
            f"Intensity for {100 * UPPER_MARKER}% detection",
            ws.number_to_dp(self.logistic_x_from_p(UPPER_MARKER), DP),
        )
        h += """
            </table>
        """
        h += self.get_trial_html(req)
        return h