15
15
xg , yg = np .meshgrid (xar , xar )
16
16
extent = [xg .min (), xg .max (), yg .min (), yg .max ()]
17
17
18
- epsilon = 0.001
19
- kernel = mfg .kernels .EuclideanKernel (
20
- nx , nx , 0 , 1 , 0 , 1 , epsilon )
21
18
22
19
23
20
# Define domain mask
@@ -39,15 +36,9 @@ def mask_to_img(mask: np.ndarray):
39
36
rho_0 /= rho_0 .sum ()
40
37
41
38
fig = plt .figure ()
42
- plt .subplot (1 ,2 ,1 )
43
39
plt .imshow (mask_img , zorder = 2 )
44
40
plt .imshow (rho_0 , cmap = plt .cm .Blues , zorder = 0 , origin = 'lower' )
45
41
46
- plt .subplot (1 ,2 ,2 )
47
- plt .imshow (mask_img , zorder = 2 )
48
- plt .imshow (kernel (rho_0 ), cmap = plt .cm .Blues , zorder = 0 , origin = 'lower' )
49
-
50
-
51
42
## Problem setup
52
43
53
44
congest_max = 1.01 * rho_0 .max ()
@@ -60,12 +51,6 @@ def mask_to_img(mask: np.ndarray):
60
51
exit_img = mask_to_img (exit_mask )
61
52
exit_img [exit_mask .astype (bool ), 0 ] = .8
62
53
63
- plt .figure ()
64
- plt .imshow (mask_img , zorder = 2 )
65
- plt .imshow (exit_img , zorder = 1 , origin = 'lower' )
66
- plt .imshow (rho_0 , cmap = plt .cm .Blues , zorder = 0 , origin = 'lower' )
67
- plt .show ()
68
-
69
54
boundary_ = np .ma .MaskedArray (1. - exit_mask , mask = mask )
70
55
potential = skfmm .travel_time (boundary_ , np .ones_like (boundary_ ))
71
56
@@ -74,10 +59,52 @@ def mask_to_img(mask: np.ndarray):
74
59
plt .imshow (exit_img , zorder = 1 , origin = 'lower' , extent = extent )
75
60
ct = plt .contourf (potential , zorder = 1 , levels = 40 , extent = extent )
76
61
plt .title ("Potential function $\\ Psi$" )
77
- plt .show ()
78
62
79
- prox = mfg .prox .CongestionObstacleProx (mask , congest_max , potential )
80
63
64
+ terminal_prox = mfg .prox .CongestionObstacleProx (mask , congest_max , potential )
65
+ running_prox = mfg .prox .CongestionObstacleProx (mask , congest_max , np .zeros_like (potential ))
81
66
67
+ N_t = 31
68
+ dt = 1. / (N_t - 1 )
69
+ epsilon = 0.1
70
+ kernel = mfg .kernels .EuclideanKernel (
71
+ nx , nx , 0 , 1 , 0 , 1 , epsilon * dt )
82
72
83
73
74
+ sinkhorn = mfg .sinkhorn .MultiSinkhorn (
75
+ running_prox , terminal_prox ,
76
+ kernel , rho_0 )
77
+
78
+
79
+ a_s = [
80
+ np .ones_like (rho_0 , order = 'F' ) for _ in range (N_t )
81
+ ]
82
+
83
+ print ("Running sinkhorn..." )
84
+ import time
85
+ t_a = time .time ()
86
+ num_iters = 1
87
+ sinkhorn .run (a_s , num_iters )
88
+ print ("Elapsed time:" , time .time () - t_a )
89
+
90
+ print ("Computing marginals..." )
91
+ marginals = sinkhorn .get_marginals (a_s )
92
+
93
+ skip = 5
94
+ steps_to_plot = list (np .arange (N_t )[::skip ])
95
+
96
+ ncols = 3
97
+ nrows = len (steps_to_plot ) // 3
98
+
99
+ fig , axes = plt .subplots (nrows , ncols )
100
+ axes = axes .ravel ()
101
+
102
+ for i , t in enumerate (steps_to_plot ):
103
+ m = marginals [t ]
104
+ if i < len (axes ):
105
+ ax = axes [i ]
106
+ ax .imshow (mask_img , zorder = 2 , origin = 'lower' , extent = extent )
107
+ ax .imshow (m , zorder = 1 , origin = 'lower' , extent = extent , cmap = plt .cm .Blues )
108
+ ax .set_title ("Time step $t=%d$" % t )
109
+
110
+ plt .show ()
0 commit comments