Skip to content

Commit 73da5c4

Browse files
committedJan 9, 2020
sinkhorn bindings
1 parent 7548704 commit 73da5c4

14 files changed

+100
-51
lines changed
 

‎CMakeLists.txt

+3-4
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ message(STATUS "Python includes at ${Python_INCLUDE_DIRS} ${Python_NumPy_INCLUDE
1717
set(EntropicMFG_HEADERS
1818
Kernels.h
1919
Utils.h
20-
MultiSinkhorn.h
2120
KLOperator.h
2221
CongestionOperator.h
2322
RectangularGridLaplacian.h
24-
MessagePassing.h)
23+
MessagePassing.h
24+
MultiSinkhorn.h)
2525

2626
set(EntropicMFG_SOURCES
2727
MultiSinkhorn.cpp
@@ -32,8 +32,7 @@ add_library(entropicmfg SHARED ${EntropicMFG_SOURCES})
3232
target_include_directories(entropicmfg PUBLIC .)
3333
target_link_libraries(entropicmfg Eigen3::Eigen)
3434
target_link_libraries(entropicmfg OpenMP::OpenMP_CXX)
35-
set_target_properties(entropicmfg PROPERTIES
36-
LINKER_LANGUAGE CXX)
35+
set_target_properties(entropicmfg PROPERTIES LINKER_LANGUAGE CXX)
3736
target_compile_options(entropicmfg PRIVATE -fPIC)
3837

3938

‎CongestionOperator.h

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <omp.h>
34
#include <Eigen/Dense>
45
#include "KLOperator.h"
56

‎KLOperator.h

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <memory>
34
#include <Eigen/Core>
45

56

@@ -15,4 +16,6 @@ class BaseProximalOperator {
1516
virtual MatrixXd operator()(const MatrixXd& x) const = 0;
1617
};
1718

19+
typedef std::shared_ptr<BaseProximalOperator> ProxPtr;
20+
1821
} // namespace klprox

‎MessagePassing.cpp

+10-8
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,32 @@
11
#include "MessagePassing.h"
22
#include "Kernels.h"
33

4+
#include <omp.h>
45
#include <vector>
56
#include <Eigen/Core>
67

78
using kernels::KernelPtr;
89
using namespace Eigen;
910

1011

11-
namespace algorithms
12+
namespace messages
1213
{
1314

1415

15-
MatrixXd contract(std::vector<Ref<MatrixXd>>& potentials, size_t idx, KernelPtr ker) {
16+
MatrixXd contract(std::vector<Ref<MatrixXd>>& potentials, const int idx, KernelPtr ker) {
1617
size_t nx = potentials[0].rows();
1718
size_t ny = potentials[0].cols();
19+
1820
MatrixXd A_ = MatrixXd::Ones(nx, ny);
19-
for (size_t k=0; k < idx - 1; k++) {
20-
A_ = (*ker)(potentials[k].array() * A_.array());
21+
for (int k=0; k < idx; k++) {
22+
A_ = ker->operator()(potentials[k].array() * A_.array());
2123
}
2224

25+
2326
MatrixXd B_ = MatrixXd::Ones(nx, ny);
24-
for (size_t k=potentials.size()-1; k > idx + 1; k--) {
25-
B_ = (*ker)(potentials[k].array() * A_.array());
27+
for (int k=potentials.size()-1; k > idx; k--) {
28+
B_ = ker->operator()(potentials[k].array() * B_.array());
2629
}
27-
2830
return A_.array() * B_.array();
2931
}
3032

@@ -35,7 +37,7 @@ std::vector<MatrixXd> compute_marginals(std::vector<Ref<MatrixXd>>& potentials,
3537
std::vector<MatrixXd> result(num_marginals);
3638

3739
#pragma omp parallel for
38-
for (size_t i=0; i < num_marginals; i++) {
40+
for (int i=0; i < num_marginals; i++) {
3941
result[i] = potentials[i].array() * contract(potentials, i, ker).array();
4042
}
4143

‎MessagePassing.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace messages
2121
* @return
2222
*/
2323
MatrixXd contract(std::vector<Ref<MatrixXd>>& potentials,
24-
const size_t idx, KernelPtr ker);
24+
const int idx, KernelPtr ker);
2525

2626

2727
std::vector<MatrixXd> compute_marginals(std::vector<Ref<MatrixXd>>& potentials, KernelPtr ker);

‎MultiSinkhorn.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ void MultimarginalSinkhorn::iterate(std::vector<Ref<MatrixXd>>& potentials) {
1414

1515
ArrayXXd conv;
1616

17-
1817
conv = contract(potentials, 0, kernel);
1918
potentials[0] = rho_0.array() / conv;
2019

@@ -28,13 +27,13 @@ void MultimarginalSinkhorn::iterate(std::vector<Ref<MatrixXd>>& potentials) {
2827

2928
}
3029

31-
void MultimarginalSinkhorn::run_sinkhorn(std::vector<Ref<MatrixXd>>& potentials, int num_iterations) {
30+
void MultimarginalSinkhorn::run(std::vector<Ref<MatrixXd>>& potentials, int num_iterations) {
3231
for (int i = 0; i < num_iterations; i++) {
3332
this->iterate(potentials);
3433
}
3534
}
3635

37-
inline std::vector<MatrixXd> MultimarginalSinkhorn::get_marginals(std::vector<Ref<MatrixXd>>& potentials) {
36+
std::vector<MatrixXd> MultimarginalSinkhorn::get_marginals(std::vector<Ref<MatrixXd>>& potentials) {
3837
return compute_marginals(potentials, kernel);
3938
}
4039

‎MultiSinkhorn.h

+13-8
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,37 @@ using std::shared_ptr;
1919
namespace sinkhorn {
2020

2121
using klprox::BaseProximalOperator;
22+
using klprox::ProxPtr;
2223

2324
class MultimarginalSinkhorn {
2425
private:
2526
/// Running cost proximal operator
26-
shared_ptr<BaseProximalOperator> running;
27+
ProxPtr running;
2728
/// Terminal cost proximal operator
28-
shared_ptr<BaseProximalOperator> terminal;
29+
ProxPtr terminal;
2930
/// Initial marginal
3031
MatrixXd rho_0;
3132
/// Kernel
3233
kernels::KernelPtr kernel;
3334

3435

3536
public:
37+
void setInitialDistribution(Ref<const MatrixXd>& rho_) {
38+
rho_0 = rho_;
39+
}
3640
MultimarginalSinkhorn(
37-
shared_ptr<BaseProximalOperator> running,
38-
shared_ptr<BaseProximalOperator> terminal,
39-
kernels::KernelPtr kernel):
40-
running(running), terminal(terminal), kernel(kernel) {}
41+
ProxPtr running,
42+
ProxPtr terminal,
43+
kernels::KernelPtr kernel,
44+
MatrixXd& rho):
45+
running(running), terminal(terminal), kernel(kernel), rho_0(rho) {}
4146

4247
/// Perform one iterate of the multimarginal Sinkhorn algorithm.
4348
void iterate(std::vector<Ref<MatrixXd>>& potentials);
4449

45-
void run_sinkhorn(std::vector<Ref<MatrixXd>>& potentials, int num_iterations);
50+
void run(std::vector<Ref<MatrixXd>>& potentials, int num_iterations);
4651

47-
inline std::vector<MatrixXd> get_marginals(std::vector<Ref<MatrixXd>>& potentials);
52+
std::vector<MatrixXd> get_marginals(std::vector<Ref<MatrixXd>>& potentials);
4853
};
4954

5055
}

‎examples/crowdmodel.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ int main(int argc, char* argv[]) {
5252

5353
shared_ptr<CongestionObstacleProx> terminal(running);
5454

55-
sinkhorn::MultimarginalSinkhorn sink(running, terminal, kernel_ptr);
55+
sinkhorn::MultimarginalSinkhorn sink(running, terminal, kernel_ptr, rho_0);
5656

5757

5858
return 1;

‎python/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ set(SOURCES
1010
pybind11_add_module(pyentropicmfg ${SOURCES})
1111
target_include_directories(pyentropicmfg PRIVATE ../)
1212
target_link_libraries(pyentropicmfg PRIVATE entropicmfg)
13+
target_link_libraries(pyentropicmfg PRIVATE OpenMP::OpenMP_CXX)
1314
target_compile_options(pyentropicmfg PRIVATE -fPIC)

‎python/bindings.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
#include <pybind11/pybind11.h>
2+
#include <pybind11/eigen.h>
3+
4+
#include <omp.h>
25

36

47
namespace py = pybind11;
58

69

710
void bind_prox(py::module &);
8-
void bind_sinkhorn(py::module &);
911
void bind_kernels(py::module &);
12+
void bind_sinkhorn(py::module &);
1013

1114

1215
PYBIND11_MODULE(pyentropicmfg, m) {
1316
m.doc() = "A library to solve variational mean-field games using an entropy "
1417
"minimization approach.";
1518

1619
bind_prox(m);
17-
bind_sinkhorn(m);
1820
bind_kernels(m);
21+
bind_sinkhorn(m);
1922

2023
}
2124

‎python/examples/congestion.py

+44-17
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
xg, yg = np.meshgrid(xar, xar)
1616
extent = [xg.min(), xg.max(), yg.min(), yg.max()]
1717

18-
epsilon = 0.001
19-
kernel = mfg.kernels.EuclideanKernel(
20-
nx, nx, 0, 1, 0, 1, epsilon)
2118

2219

2320
# Define domain mask
@@ -39,15 +36,9 @@ def mask_to_img(mask: np.ndarray):
3936
rho_0 /= rho_0.sum()
4037

4138
fig = plt.figure()
42-
plt.subplot(1,2,1)
4339
plt.imshow(mask_img, zorder=2)
4440
plt.imshow(rho_0, cmap=plt.cm.Blues, zorder=0, origin='lower')
4541

46-
plt.subplot(1,2,2)
47-
plt.imshow(mask_img, zorder=2)
48-
plt.imshow(kernel(rho_0), cmap=plt.cm.Blues, zorder=0, origin='lower')
49-
50-
5142
## Problem setup
5243

5344
congest_max = 1.01 * rho_0.max()
@@ -60,12 +51,6 @@ def mask_to_img(mask: np.ndarray):
6051
exit_img = mask_to_img(exit_mask)
6152
exit_img[exit_mask.astype(bool), 0] = .8
6253

63-
plt.figure()
64-
plt.imshow(mask_img, zorder=2)
65-
plt.imshow(exit_img, zorder=1, origin='lower')
66-
plt.imshow(rho_0, cmap=plt.cm.Blues, zorder=0, origin='lower')
67-
plt.show()
68-
6954
boundary_ = np.ma.MaskedArray(1. - exit_mask, mask=mask)
7055
potential = skfmm.travel_time(boundary_, np.ones_like(boundary_))
7156

@@ -74,10 +59,52 @@ def mask_to_img(mask: np.ndarray):
7459
plt.imshow(exit_img, zorder=1, origin='lower', extent=extent)
7560
ct = plt.contourf(potential, zorder=1, levels=40, extent=extent)
7661
plt.title("Potential function $\\Psi$")
77-
plt.show()
7862

79-
prox = mfg.prox.CongestionObstacleProx(mask, congest_max, potential)
8063

64+
terminal_prox = mfg.prox.CongestionObstacleProx(mask, congest_max, potential)
65+
running_prox = mfg.prox.CongestionObstacleProx(mask, congest_max, np.zeros_like(potential))
8166

67+
N_t = 31
68+
dt = 1./ (N_t - 1)
69+
epsilon = 0.1
70+
kernel = mfg.kernels.EuclideanKernel(
71+
nx, nx, 0, 1, 0, 1, epsilon * dt)
8272

8373

74+
sinkhorn = mfg.sinkhorn.MultiSinkhorn(
75+
running_prox, terminal_prox,
76+
kernel, rho_0)
77+
78+
79+
a_s = [
80+
np.ones_like(rho_0, order='F') for _ in range(N_t)
81+
]
82+
83+
print("Running sinkhorn...")
84+
import time
85+
t_a = time.time()
86+
num_iters = 1
87+
sinkhorn.run(a_s, num_iters)
88+
print("Elapsed time:", time.time() - t_a)
89+
90+
print("Computing marginals...")
91+
marginals = sinkhorn.get_marginals(a_s)
92+
93+
skip = 5
94+
steps_to_plot = list(np.arange(N_t)[::skip])
95+
96+
ncols = 3
97+
nrows = len(steps_to_plot) // 3
98+
99+
fig, axes = plt.subplots(nrows, ncols)
100+
axes = axes.ravel()
101+
102+
for i, t in enumerate(steps_to_plot):
103+
m = marginals[t]
104+
if i < len(axes):
105+
ax = axes[i]
106+
ax.imshow(mask_img, zorder=2, origin='lower', extent=extent)
107+
ax.imshow(m, zorder=1, origin='lower', extent=extent, cmap=plt.cm.Blues)
108+
ax.set_title("Time step $t=%d$" % t)
109+
110+
plt.show()

‎python/kernels.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ void bind_kernels(py::module& m) {
3434
"kernels", "Gibbs kernels used in the optimal transport formulation of the MFG. "
3535
"These represent the 2-marginal of the Wiener measure.");
3636

37-
py::class_<BaseKernel, PyBaseKernel>(m2, "BaseKernel")
37+
py::class_<BaseKernel, PyBaseKernel, KernelPtr>(m2, "BaseKernel")
3838
.def("__call__", &BaseKernel::operator());
3939

40-
py::class_<Kernel2D, BaseKernel>(m2, "EuclideanKernel")
40+
py::class_<Kernel2D, BaseKernel, std::shared_ptr<Kernel2D>>(m2, "EuclideanKernel")
4141
.def(py::init<size_t, size_t, double, double,
4242
double, double, double>());
4343

‎python/operators.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <pybind11/eigen.h>
33
#include <pybind11/operators.h>
44

5+
#include "omp.h"
56
#include "KLOperator.h"
67
#include "CongestionOperator.h"
78

@@ -31,16 +32,16 @@ void bind_prox(py::module& m) {
3132

3233
py::module m2 = m.def_submodule("prox", "Proximal operators");
3334

34-
py::class_<BaseProximalOperator, PyBaseProximalOperator>(m2, "BaseProximalOperator")
35+
py::class_<BaseProximalOperator, PyBaseProximalOperator, ProxPtr>(m2, "BaseProximalOperator")
3536
.def("__call__", &BaseProximalOperator::operator());
3637

37-
py::class_<ObstacleProx, BaseProximalOperator>(m2, "ObstacleProx")
38+
py::class_<ObstacleProx, BaseProximalOperator, std::shared_ptr<ObstacleProx>>(m2, "ObstacleProx")
3839
.def(py::init<const Eigen::ArrayXXi&>());
3940

40-
py::class_<CongestionPotentialProx, BaseProximalOperator>(m2, "CongestionPotentialProx")
41+
py::class_<CongestionPotentialProx, BaseProximalOperator, std::shared_ptr<CongestionPotentialProx>>(m2, "CongestionPotentialProx")
4142
.def(py::init<double, const Eigen::MatrixXd&>());
4243

43-
py::class_<CongestionObstacleProx, CongestionPotentialProx, ObstacleProx>(m2, "CongestionObstacleProx")
44+
py::class_<CongestionObstacleProx, CongestionPotentialProx, ObstacleProx, std::shared_ptr<CongestionObstacleProx>>(m2, "CongestionObstacleProx")
4445
.def(py::init<const ArrayXXi&, double, const Eigen::MatrixXd&>());
4546

4647
}

‎python/sinkhorn.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <pybind11/pybind11.h>
2+
#include <pybind11/stl.h>
23
#include <pybind11/eigen.h>
34

45
#include "MultiSinkhorn.h"
@@ -16,7 +17,14 @@ void bind_sinkhorn(py::module& m) {
1617
"to update the dual potentials and compute the marginals.");
1718

1819

19-
20+
py::class_<MultimarginalSinkhorn>(m2, "MultiSinkhorn")
21+
.def(py::init<ProxPtr, ProxPtr,
22+
kernels::KernelPtr,
23+
Eigen::MatrixXd&>())
24+
.def("iterate", &MultimarginalSinkhorn::iterate)
25+
.def("run", &MultimarginalSinkhorn::run)
26+
.def("get_marginals", &MultimarginalSinkhorn::get_marginals);
27+
2028

2129
}
2230

0 commit comments

Comments
 (0)