|
| 1 | +open Common |
| 2 | + |
| 3 | +let add_err (e : 'e) = function Error es -> Error (es @ [e]) | Ok _ -> Error [e] |
| 4 | + |
| 5 | +module Debug = struct |
| 6 | + let disabled_ fmt = Stdlib.(Printf.ifprintf stderr fmt) |
| 7 | + |
| 8 | + let enabled_ fmt = |
| 9 | + Stdlib.( |
| 10 | + Printf.ksprintf |
| 11 | + (fun s -> |
| 12 | + prerr_string s; |
| 13 | + prerr_newline ()) |
| 14 | + fmt) |
| 15 | + |
| 16 | + let d_ = disabled_ |
| 17 | + let i_ = enabled_ |
| 18 | +end |
| 19 | + |
| 20 | +open Debug |
| 21 | + |
| 22 | +module ThreadH = Hashtbl.Make (struct |
| 23 | + type t = Thread.t |
| 24 | + |
| 25 | + let equal t1 t2 = compare (Thread.id t1) (Thread.id t2) = 0 |
| 26 | + let hash t = Hashtbl.hash (Thread.id t) |
| 27 | +end) |
| 28 | + |
| 29 | +module Counter = struct |
| 30 | + |
| 31 | + open Ctx_util |
| 32 | + open Ctx_util.Syntax |
| 33 | + |
| 34 | + let lock_if b m = if b then Mutex.protect m else empty_context' () |
| 35 | + |
| 36 | + (** Note: we enforce that spawned threads don't raise uncaught exceptions, |
| 37 | + which in theory changes the semantics of threads. The value of being |
| 38 | + able to report stray exceptions outweighs the slim chance anyone |
| 39 | + would rely on being able to ignore exceptions in threads. *) |
| 40 | + |
| 41 | + type 'a finished = Return of 'a | Uncaught of Util.exn_info | Overflow of int |
| 42 | + type 'a state = Running | Finished of 'a finished |
| 43 | + |
| 44 | + type 'a group = { |
| 45 | + mutable state : 'a state; |
| 46 | + finished : Condition.t; (* predicate: state <> Running, mutex: owner.mut *) |
| 47 | + mutable thread_count : int; |
| 48 | + thread_limit : int option; |
| 49 | + owner : t; |
| 50 | + } |
| 51 | + |
| 52 | + and g = G : 'a group -> g |
| 53 | + and t = { mut : Mutex.t; groups : g ThreadH.t } |
| 54 | + |
| 55 | + let finish ?(lock = true) group fin = |
| 56 | + d_ "finish"; |
| 57 | + |
| 58 | + let< _ = lock_if lock group.owner.mut in |
| 59 | + if group.state <> Running then failwith "finish: already finished"; |
| 60 | + group.state <- Finished fin; |
| 61 | + Condition.broadcast group.finished |
| 62 | + |
| 63 | + let try_finish group fin = |
| 64 | + let< _ = Mutex.protect group.owner.mut in |
| 65 | + if group.state = Running then |
| 66 | + finish ~lock:false group fin |
| 67 | + |
| 68 | + let try_return group x = try_finish group (Return x) |
| 69 | + |
| 70 | + let spawn_thread ?(lock = true) ?group cnt (f : _ -> unit) x = |
| 71 | + d_ "create_thread"; |
| 72 | + |
| 73 | + let make_thread group f x = |
| 74 | + let cnt = group.owner in |
| 75 | + let tid = |
| 76 | + Thread.create |
| 77 | + (fun () -> |
| 78 | + d_ "thread#%d started" Thread.(self () |> id); |
| 79 | + Fun.protect |
| 80 | + (fun () -> |
| 81 | + Util.try_to_result f x |
| 82 | + |> Result.iter_error (fun e -> try_finish group (Uncaught e))) |
| 83 | + ~finally:(fun () -> |
| 84 | + let< _ = Mutex.protect cnt.mut in |
| 85 | + let tid = Thread.self () in |
| 86 | + group.thread_count <- group.thread_count - 1; |
| 87 | + ThreadH.remove cnt.groups tid; |
| 88 | + d_ "thread#%d finished" Thread.(self () |> id))) |
| 89 | + () |
| 90 | + in |
| 91 | + ThreadH.replace cnt.groups tid (G group); |
| 92 | + group.thread_count <- group.thread_count + 1; |
| 93 | + tid |
| 94 | + in |
| 95 | + |
| 96 | + let spawn = fun group -> |
| 97 | + match group.state, group.thread_limit with |
| 98 | + | Running, Some limit when group.thread_count >= limit -> |
| 99 | + finish ~lock:false group (Overflow limit); Thread.self () |
| 100 | + | Running, _ -> make_thread group f x |
| 101 | + | Finished _, _ -> Thread.self () |
| 102 | + in |
| 103 | + |
| 104 | + let< _ = lock_if lock cnt.mut in |
| 105 | + match group with |
| 106 | + | Some g -> spawn g |
| 107 | + | None -> let G g = ThreadH.find cnt.groups (Thread.self ()) in spawn g |
| 108 | + |
| 109 | + let create_counter () = { mut = Mutex.create (); groups = ThreadH.create 32 } |
| 110 | + |
| 111 | + let create_group ?thread_limit cnt = |
| 112 | + d_ "create_group"; |
| 113 | + (* no lock needed, since we don't write to the owner; |
| 114 | + operations on the group will lock instead *) |
| 115 | + { |
| 116 | + state = Running; |
| 117 | + finished = Condition.create (); |
| 118 | + thread_count = 0; |
| 119 | + thread_limit; |
| 120 | + owner = cnt; |
| 121 | + } |
| 122 | + |
| 123 | + let get_thread_count group = |
| 124 | + let< _ = Mutex.protect group.owner.mut in group.thread_count |
| 125 | + |
| 126 | + (** Wait for threads in a group to complete. Group must be finished first. *) |
| 127 | + let join_group ~leftover_thread_limit ~timeout group = |
| 128 | + (* busy waits to implement timeout *) |
| 129 | + d_ "join_group"; |
| 130 | + |
| 131 | + let _ = |
| 132 | + let< _ = Mutex.protect group.owner.mut in |
| 133 | + if group.state = Running then failwith "join_group: still running" |
| 134 | + in |
| 135 | + |
| 136 | + let time0 = Mtime_clock.counter () in |
| 137 | + let rec loop () = |
| 138 | + let remaining = get_thread_count group in |
| 139 | + if remaining <= leftover_thread_limit then remaining |
| 140 | + else if Mtime.Span.compare (Mtime_clock.count time0) timeout >= 0 then remaining |
| 141 | + else (Thread.yield (); loop ()) |
| 142 | + in |
| 143 | + loop () |
| 144 | + |
| 145 | + (** {1 High-level group operations} *) |
| 146 | + |
| 147 | + type thread_group_err = |
| 148 | + | ThreadLimitReached of int |
| 149 | + | ThreadsLeftOver of { left_over : int; limit : int } |
| 150 | + | ExceptionRaised of { main : bool; exn_info : Util.exn_info } |
| 151 | + |
| 152 | + (** Create a group that runs the given function, then sets the return value |
| 153 | + as the return value of the group. *) |
| 154 | + let spawn_thread_group ?thread_limit cnt f x = |
| 155 | + let group = create_group ?thread_limit cnt in |
| 156 | + spawn_thread ~group cnt Util.(try_return group % try_to_result f) x |> ignore; |
| 157 | + group |
| 158 | + |
| 159 | + (** Wait for a group to finish then join its threads ({!join_group}). *) |
| 160 | + let collect_thread_group ?leftover_limit group = |
| 161 | + let cnt = group.owner in |
| 162 | + let fin = |
| 163 | + let rec loop () = match group.state with |
| 164 | + | Running -> d_ "still running"; Condition.wait group.finished cnt.mut; loop () |
| 165 | + | Finished fin -> fin |
| 166 | + in |
| 167 | + let< _ = Mutex.protect cnt.mut in loop () |
| 168 | + in |
| 169 | + |
| 170 | + let leftover_count = |
| 171 | + match leftover_limit with |
| 172 | + | Some (n, t) -> join_group ~leftover_thread_limit:n ~timeout:t group |
| 173 | + | None -> get_thread_count group |
| 174 | + in |
| 175 | + |
| 176 | + fin, leftover_count |
| 177 | + |
| 178 | + (** Check the return value of {!collect_thread_group} *) |
| 179 | + let check_thread_group_result ?leftover_limit fin leftover_count = |
| 180 | + let r = match fin with |
| 181 | + | Return (Ok x) -> Ok x |
| 182 | + | Return (Error e) -> Error [ExceptionRaised { main = true; exn_info = e }] |
| 183 | + | Uncaught e -> Error [ExceptionRaised { main = false; exn_info = e }] |
| 184 | + | Overflow l -> Error [ThreadLimitReached l] |
| 185 | + in |
| 186 | + |
| 187 | + match leftover_limit with |
| 188 | + | Some (lim, (_ : Mtime.span)) when leftover_count > lim -> |
| 189 | + r |> add_err (ThreadsLeftOver { left_over = leftover_count ; limit = lim }) |
| 190 | + | _ -> r |
| 191 | + |
| 192 | + (** Combines {!spawn_thread_group}, {!collect_thread_group}, and {!check_thread_group_result}.*) |
| 193 | + let run_in_thread_group ?thread_limit ?leftover_limit cnt f x = |
| 194 | + let group = spawn_thread_group cnt f x ?thread_limit in |
| 195 | + let fin, leftover_count = collect_thread_group group ?leftover_limit in |
| 196 | + check_thread_group_result ?leftover_limit fin leftover_count |
| 197 | + |
| 198 | + (* the rest is for error reporting, could do with less code/abstraction... *) |
| 199 | + |
| 200 | + let string_of_thread_group_err = function |
| 201 | + | ThreadLimitReached n -> |
| 202 | + Printf.sprintf "Too many threads were used (> %d)" n |
| 203 | + | ThreadsLeftOver { left_over = n; limit } -> |
| 204 | + Printf.sprintf "Too many threads were left running (%d > %d)" n limit |
| 205 | + | ExceptionRaised { main; exn_info = Util.{exn; backtrace} } -> |
| 206 | + Printf.sprintf "%s thread raised an exception: %s\n%s" |
| 207 | + (if main then "The main" else "A created") |
| 208 | + (Printexc.to_string exn) (Printexc.raw_backtrace_to_string backtrace) |
| 209 | + |> String.trim |
| 210 | + |
| 211 | + let string_of_thread_group_errs = function |
| 212 | + | [] -> "Unknown error in a thread group" (* this shouldn't happen *) |
| 213 | + | [err] -> "Error in a thread group: " ^ string_of_thread_group_err err |
| 214 | + | errs -> |
| 215 | + "Multiple errors in a thread group:\n" ^ |
| 216 | + (errs |> List.mapi (fun i err -> [ |
| 217 | + Printf.sprintf "+----------- %d -----------+" (i + 1); |
| 218 | + string_of_thread_group_err err]) |
| 219 | + |> List.concat |> String.concat "\n") |
| 220 | + |
| 221 | + exception ThreadGroupErrs of thread_group_err list |
| 222 | + |
| 223 | + let _ = |
| 224 | + Printexc.register_printer |
| 225 | + (function ThreadGroupErrs errs -> Some (string_of_thread_group_errs errs) | _ -> None) |
| 226 | + |
| 227 | + let thread_group_result_to_exn = function |
| 228 | + | Ok x -> x |
| 229 | + | Error errs -> raise (ThreadGroupErrs errs) |
| 230 | + |
| 231 | +end |
| 232 | + |
| 233 | +module CounterInstance () = struct |
| 234 | + let instance = Counter.create_counter () |
| 235 | + |
| 236 | + (** Like {!Stdlib.Thread}, but with counted threads *) |
| 237 | + module Thread = struct |
| 238 | + include Thread |
| 239 | + let create f x = Counter.spawn_thread instance Util.(ignore % f) x |
| 240 | + end |
| 241 | +end |
0 commit comments