
class fairgrad.torch.cross_entropy.CrossEntropyLoss(reduction='mean', fairness_measure=None, y_train=None, s_train=None, y_desirable=[1], epsilon=0.0, fairness_rate=0.01, **kwargs)

This is an extension of the CrossEntropyLoss provided by PyTorch (torch.nn.CrossEntropyLoss). See the PyTorch documentation for details of the cross-entropy loss itself.

  • reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'

  • fairness_measure (string, FairnessMeasure) – The fairness measure to enforce. Currently supported: “equal_odds”, “equal_opportunity”, “demographic_parity”, and “accuracy_parity”.

  • y_train (np.asarray[int], Tensor, optional) – The labels of all training examples.

  • s_train (np.asarray[int], Tensor, optional) – The sensitive attribute of every training example. When there are several sensitive attributes, each unique combination counts as one sensitive-attribute value. For instance, with two binary attributes, gender (male, female) and age (above 45, below 45), there are 4 unique sensitive-attribute values in total.

  • y_desirable (np.asarray[int], Tensor, optional) – All desirable labels. Only used with equal opportunity (“equal_opportunity”). Default: [1]

  • epsilon (float, optional) – The slack allowed in the final fairness level. Default: 0.0

  • fairness_rate (float, optional) – Parameter that combines the current fairness weights with the sum of the previous fairness rates. Default: 0.01

  • **kwargs – Arbitrary keyword arguments passed to torch.nn.CrossEntropyLoss upon instantiation. Use them at your own risk, as they might result in unexpected behaviour.
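
Because s_train holds one sensitive-attribute value per training example, several attributes (such as gender and age in the s_train description) have to be folded into a single integer first. The sketch below assumes a mixed-radix encoding; the helper name encode_groups is hypothetical and not part of fairgrad, and any one-to-one mapping from attribute combinations to integers should presumably work equally well. With two binary attributes it produces the 4 unique values mentioned above.

```python
from typing import List, Sequence


def encode_groups(*attributes: Sequence[int]) -> List[int]:
    """Fold several categorical sensitive attributes into one integer per example.

    Hypothetical helper (not part of fairgrad). Each attribute is a sequence of
    non-negative ints, one value per example. A mixed-radix encoding is used, so
    every distinct combination of attribute values maps to a distinct integer.
    """
    if not attributes:
        raise ValueError("need at least one attribute")
    n = len(attributes[0])
    if any(len(a) != n for a in attributes):
        raise ValueError("all attributes must have one value per example")
    # Radix of each attribute: one more than its largest observed value.
    radices = [max(a, default=0) + 1 for a in attributes]
    ids = []
    for i in range(n):
        group = 0
        for values, radix in zip(attributes, radices):
            group = group * radix + values[i]
        ids.append(group)
    return ids


gender = [0, 1, 0, 1]  # assumed coding: 0 = male, 1 = female
age = [0, 0, 1, 1]     # assumed coding: 0 = below 45, 1 = above 45
print(encode_groups(gender, age))  # [0, 2, 1, 3]
```

The resulting list could then be converted with torch.tensor(ids, dtype=torch.long) and passed as s_train.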


>>> import torch
>>> from fairgrad.torch.cross_entropy import CrossEntropyLoss
>>> logits = torch.randn(10, 5, requires_grad=True)  # model outputs: 10 examples, 5 classes
>>> target = torch.empty(10, dtype=torch.long).random_(2)  # labels in {0, 1}
>>> s = torch.empty(10, dtype=torch.long).random_(2)  # protected attribute
>>> loss = CrossEntropyLoss(y_train=target, s_train=s, fairness_measure='equal_odds')
>>> output = loss(logits, target, s, mode='train')
>>> output.backward()
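
After training, it can help to check the fairness of the resulting predictions independently of the loss. The plain-Python sketch below computes a simple equal-opportunity gap: for each desirable label, the largest difference across sensitive groups in the rate at which examples with that true label are predicted correctly. It is an illustrative evaluation helper only (the name equal_opportunity_gap is hypothetical); it is not how fairgrad computes its fairness weights, and its scale need not match the one used for epsilon. Predicted labels could be taken as the argmax of the model scores along the class dimension, e.g. scores.argmax(dim=1).tolist() for a tensor named scores.

```python
from collections import defaultdict
from typing import Dict, Iterable, Sequence


def equal_opportunity_gap(
    y_true: Sequence[int],
    y_pred: Sequence[int],
    groups: Sequence[int],
    y_desirable: Iterable[int] = (1,),
) -> float:
    """Largest between-group spread in true positive rate over desirable labels.

    Hypothetical evaluation helper (not part of fairgrad). For every desirable
    label y and every sensitive group g it computes
    P(pred == y | true == y, group == g), then returns the largest
    max-minus-min spread across groups.
    """
    if not (len(y_true) == len(y_pred) == len(groups)):
        raise ValueError("y_true, y_pred and groups must have the same length")
    gap = 0.0
    for label in y_desirable:
        hits: Dict[int, int] = defaultdict(int)    # correct predictions per group
        totals: Dict[int, int] = defaultdict(int)  # examples with this true label per group
        for t, p, g in zip(y_true, y_pred, groups):
            if t == label:
                totals[g] += 1
                hits[g] += int(p == label)
        rates = [hits[g] / totals[g] for g in totals]
        if len(rates) >= 2:  # a gap needs at least two groups
            gap = max(gap, max(rates) - min(rates))
    return gap


y_true = [1, 1, 1, 1, 0, 0]
y_pred = [1, 0, 1, 1, 0, 1]
groups = [0, 0, 1, 1, 0, 1]
print(equal_opportunity_gap(y_true, y_pred, groups))  # 0.5
```

Here group 0 has a true positive rate of 0.5 and group 1 a rate of 1.0, giving a gap of 0.5.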