Add JSD kernel #264


Merged · 23 commits merged into linkedin:main on Oct 1, 2024

Conversation

@Tcc0403 (Collaborator) commented on Sep 21, 2024

Summary

Resolve #252

Details

JSD

We expect the input $X$ and the target $Y$ to be distributions in log-space, i.e., $X = \log Q$ and $Y = \log P$.
The Jensen-Shannon divergence between two distributions $P$ and $Q$ is defined as:

$$JSD(X, Y) = JSD(P\ \Vert \ Q) = \frac{1}{2} (KL(P\ \Vert\ M) + KL(Q\ \Vert\ M))$$

where $M = \frac{1}{2}(P + Q)$ is the average distribution and $KL$ is the Kullback-Leibler divergence.
Given that $X = \log Q$ and $Y = \log P$, we can simplify the JSD expression to:

$$\begin{align} JSD(X, Y) &= \frac{1}{2}\left(\sum_i P_i \log\frac{P_i}{M_i} + \sum_i Q_i \log\frac{Q_i}{M_i}\right)\\ &= \frac{1}{2} \sum_i \left(P_i \log P_i - P_i \log M_i + Q_i \log Q_i - Q_i \log M_i\right)\\ &= \frac{1}{2} \sum_i \left(P_i \log P_i + Q_i \log Q_i - 2M_i \log M_i\right)\\ &= \frac{1}{2} \sum_i \left(P_i \cdot Y_i + Q_i \cdot X_i - 2M_i \log M_i\right) \end{align}$$
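As a quick sanity check (illustrative snippet only, not part of the kernel), the simplified form above matches the definition computed via `torch.nn.functional.kl_div`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(8, 16).log_softmax(dim=-1)  # X = log Q
Y = torch.randn(8, 16).log_softmax(dim=-1)  # Y = log P

P, Q = Y.exp(), X.exp()
M = 0.5 * (P + Q)

# Definition: JSD = 1/2 * (KL(P || M) + KL(Q || M))
jsd_def = 0.5 * (F.kl_div(M.log(), P, reduction="sum")
                 + F.kl_div(M.log(), Q, reduction="sum"))

# Simplified form derived above
jsd_simplified = 0.5 * (P * Y + Q * X - 2.0 * M * M.log()).sum()

assert torch.allclose(jsd_def, jsd_simplified, atol=1e-6)
```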

We define the point-wise JSD as:

$$JSD(X_i, Y_i) = \frac{1}{2} \left(P_i \cdot Y_i + Q_i \cdot X_i - 2M_i \log M_i\right)$$

With the point-wise JSD, it's easier to implement JSD with respect to different reduction methods in the future.
The only downside is that it creates a torch.float32 tensor with the same shape as the input.

The current implementation is hardcoded to batchmean, which is the original JSD definition.
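For illustration only, here's a naive pure-PyTorch reference of this computation with batchmean reduction (not the kernel in this PR; the `(B*T, V)` shape assumption and the name `jsd_reference` are just for this sketch):

```python
import torch

def jsd_reference(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Naive JSD reference with batchmean reduction.

    X: input log-probabilities (log Q), assumed shape (B*T, V)
    Y: target log-probabilities (log P), assumed shape (B*T, V)
    """
    P, Q = Y.exp(), X.exp()
    M = 0.5 * (P + Q)
    # point-wise JSD, accumulated in float32 as described above
    pointwise = 0.5 * (P * Y + Q * X - 2.0 * M * torch.log(M)).to(torch.float32)
    # batchmean: sum over all elements, divide by the number of rows
    return pointwise.sum() / X.shape[0]
```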

Gradients

Given:

$$JSD(X, Y) = JSD(P\ \Vert \ Q) = \frac{1}{2} (KL(P\ \Vert\ M) + KL(Q\ \Vert\ M))$$

where $Q = e^X$, $P = e^Y$, and $M = \frac{1}{2}(e^X + e^Y)$.

Gradients of $KL(P\ \Vert\ M)$ with respect to $X_i$:

$$\begin{align} \frac{\partial}{\partial X_i} \sum_j P_j \log\frac{P_j}{M_j} &= \frac{\partial}{\partial X_i} \sum_j P_j (Y_j - \log M_j)\\ &= \frac{\partial}{\partial X_i} \sum_j - P_j \log M_j\\ &= - P_i \cdot \frac{1}{M_i}\cdot \frac{e^{X_i}}{2} = -P_i\cdot \frac{Q_i}{2M_i} \end{align}$$

Gradients of $KL(Q\ \Vert\ M)$ with respect to $X_i$:

$$\begin{align} \frac{\partial}{\partial X_i} \sum_j Q_j \log\frac{Q_j}{M_j} &= \frac{\partial}{\partial X_i} \sum_j Q_j (X_j - \log M_j)\\ &= \sum_j\left( \frac{\partial Q_j}{\partial X_i}(X_j - \log M_j) + Q_j \frac{\partial (X_j - \log M_j)}{\partial X_i}\right)\\ &= Q_i(X_i - \log M_i) + Q_i \left(1 - \frac{Q_i}{2M_i}\right) \end{align}$$

Final gradient of JSD:

Combining the results from the two KL divergence terms:

$$\begin{align} \frac{\partial}{\partial X_i} JSD(X, Y) &= \frac{1}{2}\left(-P_i\cdot \frac{Q_i}{2M_i} + Q_i(X_i - \log M_i) + Q_i \left(1 - \frac{Q_i}{2M_i}\right)\right) \end{align}$$

This simplifies to:

$$\begin{align} \frac{\partial}{\partial X_i} JSD(X, Y) &= \frac{1}{2}\left(-P_i\cdot \frac{Q_i}{2M_i} + Q_i(X_i - \log M_i) + Q_i \left(1 - \frac{Q_i}{2M_i}\right)\right)\\ &= \frac{1}{2} Q_i \left(X_i - \log M_i + 1 - \frac{P_i + Q_i}{2M_i}\right), \quad \text{where } 2M_i = P_i + Q_i \\ &= \frac{1}{2}\cdot Q_i \cdot (X_i - \log M_i) \end{align}$$
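A small autograd check of this closed form (illustrative only, treating $X$ as a free variable and differentiating the simplified forward expression):

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 8).log_softmax(dim=-1).requires_grad_(True)  # log Q
Y = torch.randn(4, 8).log_softmax(dim=-1)                       # log P

P, Q = Y.exp(), X.exp()
M = 0.5 * (P + Q)
jsd = 0.5 * (P * Y + Q * X - 2.0 * M * M.log()).sum()
jsd.backward()

# closed-form gradient derived above: 1/2 * Q_i * (X_i - log M_i)
analytic = 0.5 * Q.detach() * (X.detach() - M.detach().log())
assert torch.allclose(X.grad, analytic, atol=1e-6)
```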

We store the gradients at X_ptr in the forward pass to save memory, then retrieve them through ctx in the backward function, as cross_entropy does (in-place).

Note: in-place operations on inputs might cause issues with gradient computation.
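Roughly, the pattern looks like the sketch below (pure PyTorch, no Triton; it materializes the gradient in a separate tensor instead of writing into X's buffer, so it only illustrates the "compute the gradient in forward, reuse it in backward" idea):

```python
import torch

class NaiveJSDFunction(torch.autograd.Function):
    """Illustrative sketch: compute the gradient during the forward pass
    and stash it for backward, mirroring the ctx/X_ptr trick above."""

    @staticmethod
    def forward(ctx, X, Y):
        P, Q = Y.exp(), X.exp()
        M = 0.5 * (P + Q)
        n_rows = X.shape[0]
        loss = 0.5 * (P * Y + Q * X - 2.0 * M * M.log()).sum() / n_rows
        # gradient of the batchmean loss w.r.t. X, precomputed here
        grad_X = 0.5 * Q * (X - M.log()) / n_rows
        ctx.save_for_backward(grad_X)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad_X,) = ctx.saved_tensors
        # no gradient for the target Y
        return grad_output * grad_X, None
```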

Testing Done

With inplace (storing gradients to inputs):

  • reduce memory usage by 61.54% (jsd_memory benchmark plot)
  • increase speed by 53.64% (jsd_speed benchmark plot)

Without inplace:

  • reduce memory usage by 53% (jsd_memory benchmark plot)
  • increase speed by 61% (jsd_speed benchmark plot)

  • Hardware Type: H100
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Tcc0403 (Collaborator, Author) commented on Sep 22, 2024

@yundai424
I got almost everything done except for minor cleanups and comments, which will be handled ASAP. Do I need to add any parameters, such as reduction? Since there aren't many references I can look up, I'm not sure what kind of parameters might be useful.

btw, should I change jsd to js_div?

@lancerts (Collaborator) commented:

> @yundai424 I got almost everything done except for minor cleanups and comments, which will be handled ASAP. Do I need to add any parameters, such as reduction? Since there aren't many references I can look up, I'm not sure what kind of parameters might be useful.
>
> btw, should I change jsd to js_div?

jsd is fine~

@lancerts requested a review from yundai424 on September 22, 2024 17:01
@Tcc0403 (Collaborator, Author) commented on Sep 23, 2024

Added jsd benchmark script

@Tcc0403 (Collaborator, Author) commented on Sep 25, 2024

GPU CI is failing on irrelevant tests.

@Tcc0403 (Collaborator, Author) commented on Sep 25, 2024

Pushed an implementation without in-place operations on the input; see #262 (comment). We can decide which one to use in the future.

@qingquansong (Collaborator) left a review comment:

Thanks for the efforts! Looks good to me in general. I'm assuming the log-scale input is for taking log_softmax as the input in common cases, which is more numerically stable? (One minor useful feature we could add, possibly in a future PR, is ignore-index support. We could provide an extra input label tensor for it, but that might be too specific, since general JSD does not take hard labels as input.)

@Tcc0403 (Collaborator, Author) commented on Sep 27, 2024

> I'm assuming the log-scale input is for taking log_softmax as the input in common cases, which is more numerically stable?

Yes, and torch.nn.KLDivLoss also takes its input in log-space. I think it would be better to offer similar arguments for users.
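For reference (usage sketch, not part of this PR), torch.nn.KLDivLoss with log_target=True takes both arguments in log-space, which is the convention this JSD follows:

```python
import torch

kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

X = torch.randn(4, 8).log_softmax(dim=-1)  # input:  log Q
Y = torch.randn(4, 8).log_softmax(dim=-1)  # target: log P
loss = kl(X, Y)  # computes KL(P || Q) from log-space arguments
```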

@yundai424 (Collaborator) left a review comment:

LGTM, thanks for the contribution! @qingquansong could you help create two follow-up issues for 1) adding ignore-index support for divergence losses, and 2) adding general JSD (w/ beta) support? Thanks a lot!!

@qingquansong (Collaborator) left a review comment:

LGTM! Let me know when the CI tests are fixed and I can give a quick shipping stamp. Thanks!

@lancerts enabled auto-merge (squash) on October 1, 2024 21:21
@lancerts merged commit 8e2f3a4 into linkedin:main on Oct 1, 2024
2 checks passed
@Tcc0403 deleted the jsd branch on December 1, 2024 03:13
Successfully merging this pull request may close these issues: reverse KL and JSD.