[MRG] Gaussian Mixture OT #649

Merged: 38 commits merged into PythonOT:master on Jul 30, 2024

@eloitanguy (Collaborator) commented Jun 27, 2024

TODO

  • make proj_SDP batch-friendly and backend-compatible
  • finish flow example
  • re-implement pdfs
  • documentation
  • faster bures distance matrix (re-implement bures and store square roots): einsum?
  • faster barycentric mapping of GMM plan (same idea as for bures)
  • stochastic GMM mapping
  • asserts / tests
  • ref to Delon-Desolneux
  • fix Issue Mean computed without weights in empirical gaussain OT #648
  • add nx.det
  • test nx.det
  • gaussian_pdf and gmm_pdf backends
  • GMM plan and maps example (linting WIP)
  • ot.plot.plot1D_mat update
  • test ot.plot
  • batched nx.sqrtm
  • einsum magic for ot.gmm additions (hopefully)

Types of changes

  • Added ot.gmm for Gaussian Mixture Model OT, with a GMM plan example and a GMM flow example
  • Fixed Issue #648 in ot.gaussian
  • Added ot.utils.proj_SDP for a broadcastable and backend-compatible projection of a symmetric matrix onto the SDP cone
  • Added nx.det
  • Added (..., d, d) broadcastability for nx.sqrtm
  • Rewrote ot.plot.plot1D_mat with the same API and visual improvements
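
The projection mentioned in the list above can be illustrated with a minimal NumPy sketch. It is not the code of ot.utils.proj_SDP: the function name `proj_psd_sketch` and the `vmin` parameter are hypothetical, and plain NumPy is used here instead of the POT backend `nx`. The idea is the standard one: diagonalize each symmetric matrix, clip its eigenvalues from below, and rebuild the matrix, with all leading batch dimensions broadcast.

```python
import numpy as np


def proj_psd_sketch(S, vmin=0.0):
    """Project symmetric matrices of shape (..., d, d) onto the PSD cone.

    Hypothetical helper for illustration only (not the POT implementation).
    """
    S = np.asarray(S, dtype=float)
    # Symmetrize to guard against small numerical asymmetries.
    S = 0.5 * (S + np.swapaxes(S, -1, -2))
    # Batched eigendecomposition: w has shape (..., d), P has shape (..., d, d).
    w, P = np.linalg.eigh(S)
    # Clip eigenvalues from below (vmin=0 gives the PSD cone).
    w = np.clip(w, vmin, None)
    # Rebuild P diag(w) P^T for every matrix in the batch.
    return np.einsum('...ik,...k,...jk->...ij', P, w, P)
```

Clipping the eigenvalues gives the closest PSD matrix in Frobenius norm, and a matrix that is already PSD is left unchanged up to rounding.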

Motivation and context / Related issue

New features implementing GMM OT (https://arxiv.org/pdf/1907.05254), plus an improvement to ot.plot.plot1D_mat
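
For readers unfamiliar with the linked paper, the GMM OT problem can be summarized as below. The notation is an assumption made for this summary and does not come from the PR: $\mu_0 = \sum_k \pi_0^k \mu_0^k$ and $\mu_1 = \sum_l \pi_1^l \mu_1^l$ are Gaussian mixtures with weights $\pi_0, \pi_1$, and $\Pi(\pi_0, \pi_1)$ is the set of discrete couplings between the weight vectors.

```latex
% Mixture Wasserstein distance: a discrete OT problem between components
MW_2^2(\mu_0, \mu_1)
  = \min_{w \in \Pi(\pi_0, \pi_1)}
    \sum_{k=1}^{K_0} \sum_{l=1}^{K_1} w_{kl}\, W_2^2(\mu_0^k, \mu_1^l)

% Closed-form cost between two Gaussian components (Bures-Wasserstein)
W_2^2\big(\mathcal{N}(m_0, \Sigma_0), \mathcal{N}(m_1, \Sigma_1)\big)
  = \|m_0 - m_1\|^2
  + \operatorname{tr}\Big(\Sigma_0 + \Sigma_1
      - 2\big(\Sigma_0^{1/2} \Sigma_1 \Sigma_0^{1/2}\big)^{1/2}\Big)
```

The second formula explains why matrix square roots appear throughout the discussion in this PR: every entry of the component-to-component cost matrix requires one.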

How has this been tested (if it applies)

pytest with full coverage, plus two examples in the gallery

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

codecov bot commented Jun 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.78%. Comparing base (47c5925) to head (98a7fba).
Report is 14 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #649      +/-   ##
==========================================
+ Coverage   96.71%   96.78%   +0.07%     
==========================================
  Files          86       88       +2     
  Lines       17148    17502     +354     
==========================================
+ Hits        16584    16940     +356     
+ Misses        564      562       -2     

Collaborator

@rflamary rflamary left a comment

A few comments, but this is very good work.

ot/gmm.py Outdated
Cs12 = nx.sqrtm(C_s[i])
Cs12inv = nx.inv(Cs12)

for j in range(k_t):
Collaborator

Challenge 2 (optional): here the second loop should be doable with einsum and a batched sqrtm too.

Collaborator Author

For both maps, I ended up looping over all nonzero indices in the plan; I don't know whether this is better than a smart einsum.
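
To make the einsum suggestion from this thread concrete, here is a hedged NumPy sketch of the fully batched alternative. It is not the code merged in ot/gmm.py: the helper names `sqrtm_psd` and `gaussian_map_matrices` are hypothetical, and NumPy stands in for the backend `nx`. It computes the linear part A[i, j] = Cs^{-1/2} (Cs^{1/2} Ct Cs^{1/2})^{1/2} Cs^{-1/2} of the Gaussian OT map for every pair of components at once.

```python
import numpy as np


def sqrtm_psd(C):
    """Batched square root of symmetric PSD matrices of shape (..., d, d)."""
    w, P = np.linalg.eigh(C)
    w = np.clip(w, 0.0, None)  # remove tiny negative eigenvalues
    return np.einsum('...ik,...k,...jk->...ij', P, np.sqrt(w), P)


def gaussian_map_matrices(C_s, C_t):
    """Linear parts of the OT maps between all pairs of Gaussian components.

    C_s: (k_s, d, d) and C_t: (k_t, d, d), symmetric positive definite.
    Returns an array of shape (k_s, k_t, d, d).
    """
    Cs12 = sqrtm_psd(C_s)           # (k_s, d, d)
    Cs12inv = np.linalg.inv(Cs12)   # (k_s, d, d)
    # M[i, j] = Cs12[i] @ C_t[j] @ Cs12[i]
    M = np.einsum('iab,jbc,icd->ijad', Cs12, C_t, Cs12)
    M12 = sqrtm_psd(M)              # (k_s, k_t, d, d)
    # A[i, j] = Cs12inv[i] @ M12[i, j] @ Cs12inv[i]
    return np.einsum('iab,ijbc,icd->ijad', Cs12inv, M12, Cs12inv)
```

The trade-off raised in the reply still applies: this version removes the Python loops but always materializes k_s * k_t matrices, even for pairs that carry no mass in the plan.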

# i and j, b[i, j] is the translation part
rng = np.random.RandomState(seed)

A = nx.zeros((k_s, k_t, d, d))
Collaborator

This is a BIG tensor! Again, maybe it can be done with einsum.

Collaborator Author

Looping over the nonzero entries of the plan is probably faster for this function, since there are at most k_s + k_t - 1 nonzero entries. I'll look into that.
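
The sparse alternative described in this reply can be sketched as follows. This is an illustration under assumptions, not the merged implementation: `maps_on_support`, `sqrtm_psd_single`, and the `tol` threshold are hypothetical names, and NumPy replaces the backend `nx`. Only the component pairs that actually carry mass in the plan get a map, so the (k_s, k_t, d, d) tensor is never allocated.

```python
import numpy as np


def sqrtm_psd_single(C):
    """Square root of one symmetric PSD matrix via its eigendecomposition."""
    w, P = np.linalg.eigh(C)
    return (P * np.sqrt(np.clip(w, 0.0, None))) @ P.T


def maps_on_support(w, C_s, C_t, tol=1e-12):
    """Compute Gaussian OT map matrices only where the plan w is nonzero.

    w: (k_s, k_t) plan, C_s: (k_s, d, d), C_t: (k_t, d, d).
    Returns a dict {(i, j): A_ij} with one (d, d) matrix per active pair.
    """
    maps = {}
    for i, j in np.argwhere(w > tol):
        Cs12 = sqrtm_psd_single(C_s[i])
        Cs12inv = np.linalg.inv(Cs12)
        M12 = sqrtm_psd_single(Cs12 @ C_t[j] @ Cs12)
        maps[(int(i), int(j))] = Cs12inv @ M12 @ Cs12inv
    return maps
```

When the plan has at most k_s + k_t - 1 nonzero entries, as stated in the reply, the loop runs at most that many times, compared with k_s * k_t matrices for the dense tensor.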

ot/plot.py Outdated


def plot1D_mat(a, b, M, title=''):
r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution
def plot1D_mat(a, b, M, title='', a_label='Source distribution',
Collaborator

Thanks for the big and necessary update, but keep in mind that the function must, by default, behave very similarly to the previous version.

That is, the colors and the source/target layout should be the same as in the previous version. It is OK to have parameters that change this default behavior, and you can use whatever values you want in your examples, but this update should not change the other examples that use the function too much (except for the fill of the densities, which is much better than in the previous plot).

Collaborator Author

I made the default behaviour basically identical to the old one (just quality upgrades, everything is in the same place), and the 'xy' version can be selected with plot_style, which defaults to 'yx'. Colours also default to blue and red :)

@rflamary changed the title from "[WIP] Gaussian Mixture OT" to "[MRG] Gaussian Mixture OT" on Jul 30, 2024
@rflamary merged commit 36c9252 into PythonOT:master on Jul 30, 2024
17 checks passed

Successfully merging this pull request may close these issues.

Mean computed without weights in empirical gaussain OT
4 participants