Coverage for metrics.py : 100%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import torch
2from torch import Tensor
class CosineLoss(torch.nn.CosineSimilarity):
    """CosineLoss Implements a simple cosine similarity based loss.

    The loss is ``1 - cosine_similarity(truth, prediction)``, where the
    similarity is computed by `torch.nn.CosineSimilarity` along the ``dim``
    given at construction. It is therefore 0 for vectors pointing in the same
    direction and 2 for vectors pointing in opposite directions.
    """

    def __init__(self, *args, **kwargs) -> None:
        """__init__ Instantiates the class.

        All arguments (e.g. ``dim`` and ``eps``) are passed unchanged to
        `torch.nn.CosineSimilarity`.
        """
        super().__init__(*args, **kwargs)

    def forward(self, truth: Tensor, prediction: Tensor) -> Tensor:
        """Forward calculates the loss.

        Parameters
        ----------
        truth : Tensor
            First input, compared with ``prediction`` along ``dim``.
        prediction : Tensor
            Second input, compared with ``truth`` along ``dim``.

        Returns
        -------
        Tensor
            ``1 - cosine similarity``, with ``dim`` reduced away. If either
            input is all zeros along ``dim`` the similarity is 0 (``eps``
            guards the division), so the loss there is 1.

        Examples
        --------
        >>> loss = CosineLoss(dim=1, eps=1e-4)
        >>> loss(torch.ones([1,2,5]), torch.zeros([1,2,5]))
        tensor([[1., 1., 1., 1., 1.]])
        >>> loss(torch.zeros([1,2,5]), torch.zeros([1,2,5]))
        tensor([[1., 1., 1., 1., 1.]])
        >>> loss(torch.tensor([[1., 0.]]), torch.tensor([[2., 0.]]))
        tensor([0.])
        >>> loss(torch.tensor([[1., 0.]]), torch.tensor([[-3., 0.]]))
        tensor([2.])
        """
        similarity = super().forward(truth, prediction)
        return 1 - similarity
class PearsonCorrelation(torch.nn.Module):
    """PearsonCorrelation Implements a simple pearson correlation."""

    def __init__(self, axis: int = 1, eps: float = 1e-4) -> None:
        """__init__ Instantiates the class.

        Creates a callable object to calculate the pearson correlation on an axis

        Parameters
        ----------
        axis : int, optional
            The axis over which the correlation is calculated.
            For instance, if the input has shape [5, 500] and the axis is set
            to 1, the output will be of shape [5]. On the other hand, if the axis
            is set to 0, the output will have shape [500], by default 1
        eps : float, optional
            Number added to the denominator to prevent division by 0,
            by default 1e-4
        """
        super().__init__()
        self.axis = axis
        self.eps = eps

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Forward calculates the pearson correlation between ``x`` and ``y``.

        This returns the correlation itself, not a loss. Because ``eps`` is
        added to the denominator, inputs that are constant along ``axis``
        (zero variance) give a correlation of 0 instead of NaN, and the
        magnitude of the result is slightly below 1 even for perfectly
        (anti-)correlated inputs.

        Parameters
        ----------
        x : Tensor
            First input; the examples use the same shape as ``y``.
        y : Tensor
            Second input; the examples use the same shape as ``x``.

        Returns
        -------
        Tensor
            Correlation along ``self.axis``, with that axis reduced away.

        Examples
        --------
        >>> corr = PearsonCorrelation(axis=1, eps=1e-4)
        >>> corr(torch.ones([1,2,5]), torch.zeros([1,2,5]))
        tensor([[0., 0., 0., 0., 0.]])
        >>> corr(torch.zeros([1,2,5]), torch.zeros([1,2,5]))
        tensor([[0., 0., 0., 0., 0.]])
        >>> x = torch.tensor([[1., 2., 3.]])
        >>> torch.allclose(corr(x, 2 * x), torch.tensor([1.]), atol=1e-3)
        True
        >>> torch.allclose(corr(x, -x), torch.tensor([-1.]), atol=1e-3)
        True
        >>> out = corr(torch.rand([5, 174]), torch.rand([5, 174]))
        >>> out.shape
        torch.Size([5])
        >>> corr = PearsonCorrelation(axis=0, eps=1e-4)
        >>> out = corr(torch.rand([5, 174]), torch.rand([5, 174]))
        >>> out.shape
        torch.Size([174])
        """
        # keepdim=True keeps the reduced axis (size 1) so the means broadcast
        # back against the inputs.
        vx = x - torch.mean(x, dim=self.axis, keepdim=True)
        vy = y - torch.mean(y, dim=self.axis, keepdim=True)

        num = torch.sum(vx * vy, dim=self.axis)
        denom_1 = torch.sqrt(torch.sum(vx ** 2, dim=self.axis))
        denom_2 = torch.sqrt(torch.sum(vy ** 2, dim=self.axis))
        # eps is added after the product so zero-variance inputs give 0, not NaN
        denom = (denom_1 * denom_2) + self.eps
        return num / denom