Skip to content

Commit e47c06d

Browse files
authored
Add thread counter (#21)
* For tracking and restricting threads spawned by student code
1 parent 6870040 commit e47c06d

File tree

2 files changed

+245
-0
lines changed

2 files changed

+245
-0
lines changed
+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
(library
2+
(name thread_counter)
3+
(public_name less-power.thread-counter)
4+
(libraries common mtime mtime.clock.os threads))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
open Common
2+
3+
let add_err (e : 'e) = function Error es -> Error (es @ [e]) | Ok _ -> Error [e]
4+
5+
module Debug = struct
6+
let disabled_ fmt = Stdlib.(Printf.ifprintf stderr fmt)
7+
8+
let enabled_ fmt =
9+
Stdlib.(
10+
Printf.ksprintf
11+
(fun s ->
12+
prerr_string s;
13+
prerr_newline ())
14+
fmt)
15+
16+
let d_ = disabled_
17+
let i_ = enabled_
18+
end
19+
20+
open Debug
21+
22+
module ThreadH = Hashtbl.Make (struct
23+
type t = Thread.t
24+
25+
let equal t1 t2 = compare (Thread.id t1) (Thread.id t2) = 0
26+
let hash t = Hashtbl.hash (Thread.id t)
27+
end)
28+
29+
module Counter = struct
30+
31+
open Ctx_util
32+
open Ctx_util.Syntax
33+
34+
let lock_if b m = if b then Mutex.protect m else empty_context' ()
35+
36+
(** Note: we enforce that spawned threads don't raise uncaught exceptions,
37+
which in theory changes the semantics of threads. The value of being
38+
able to report stray exceptions outweighs the slim chance anyone
39+
would rely on being able to ignore exceptions in threads. *)
40+
41+
type 'a finished = Return of 'a | Uncaught of Util.exn_info | Overflow of int
42+
type 'a state = Running | Finished of 'a finished
43+
44+
type 'a group = {
45+
mutable state : 'a state;
46+
finished : Condition.t; (* predicate: state <> Running, mutex: owner.mut *)
47+
mutable thread_count : int;
48+
thread_limit : int option;
49+
owner : t;
50+
}
51+
52+
and g = G : 'a group -> g
53+
and t = { mut : Mutex.t; groups : g ThreadH.t }
54+
55+
let finish ?(lock = true) group fin =
56+
d_ "finish";
57+
58+
let< _ = lock_if lock group.owner.mut in
59+
if group.state <> Running then failwith "finish: already finished";
60+
group.state <- Finished fin;
61+
Condition.broadcast group.finished
62+
63+
let try_finish group fin =
64+
let< _ = Mutex.protect group.owner.mut in
65+
if group.state = Running then
66+
finish ~lock:false group fin
67+
68+
let try_return group x = try_finish group (Return x)
69+
70+
let spawn_thread ?(lock = true) ?group cnt (f : _ -> unit) x =
71+
d_ "create_thread";
72+
73+
let make_thread group f x =
74+
let cnt = group.owner in
75+
let tid =
76+
Thread.create
77+
(fun () ->
78+
d_ "thread#%d started" Thread.(self () |> id);
79+
Fun.protect
80+
(fun () ->
81+
Util.try_to_result f x
82+
|> Result.iter_error (fun e -> try_finish group (Uncaught e)))
83+
~finally:(fun () ->
84+
let< _ = Mutex.protect cnt.mut in
85+
let tid = Thread.self () in
86+
group.thread_count <- group.thread_count - 1;
87+
ThreadH.remove cnt.groups tid;
88+
d_ "thread#%d finished" Thread.(self () |> id)))
89+
()
90+
in
91+
ThreadH.replace cnt.groups tid (G group);
92+
group.thread_count <- group.thread_count + 1;
93+
tid
94+
in
95+
96+
let spawn = fun group ->
97+
match group.state, group.thread_limit with
98+
| Running, Some limit when group.thread_count >= limit ->
99+
finish ~lock:false group (Overflow limit); Thread.self ()
100+
| Running, _ -> make_thread group f x
101+
| Finished _, _ -> Thread.self ()
102+
in
103+
104+
let< _ = lock_if lock cnt.mut in
105+
match group with
106+
| Some g -> spawn g
107+
| None -> let G g = ThreadH.find cnt.groups (Thread.self ()) in spawn g
108+
109+
let create_counter () = { mut = Mutex.create (); groups = ThreadH.create 32 }
110+
111+
let create_group ?thread_limit cnt =
112+
d_ "create_group";
113+
(* no lock needed, since we don't write to the owner;
114+
operations on the group will lock instead *)
115+
{
116+
state = Running;
117+
finished = Condition.create ();
118+
thread_count = 0;
119+
thread_limit;
120+
owner = cnt;
121+
}
122+
123+
let get_thread_count group =
124+
let< _ = Mutex.protect group.owner.mut in group.thread_count
125+
126+
(** Wait for threads in a group to complete. Group must be finished first. *)
127+
let join_group ~leftover_thread_limit ~timeout group =
128+
(* busy waits to implement timeout *)
129+
d_ "join_group";
130+
131+
let _ =
132+
let< _ = Mutex.protect group.owner.mut in
133+
if group.state = Running then failwith "join_group: still running"
134+
in
135+
136+
let time0 = Mtime_clock.counter () in
137+
let rec loop () =
138+
let remaining = get_thread_count group in
139+
if remaining <= leftover_thread_limit then remaining
140+
else if Mtime.Span.compare (Mtime_clock.count time0) timeout >= 0 then remaining
141+
else (Thread.yield (); loop ())
142+
in
143+
loop ()
144+
145+
(** {1 High-level group operations} *)
146+
147+
type thread_group_err =
148+
| ThreadLimitReached of int
149+
| ThreadsLeftOver of { left_over : int; limit : int }
150+
| ExceptionRaised of { main : bool; exn_info : Util.exn_info }
151+
152+
(** Create a group that runs the given function, then sets the return value
153+
as the return value of the group. *)
154+
let spawn_thread_group ?thread_limit cnt f x =
155+
let group = create_group ?thread_limit cnt in
156+
spawn_thread ~group cnt Util.(try_return group % try_to_result f) x |> ignore;
157+
group
158+
159+
(** Wait for a group to finish then join its threads ({!join_group}). *)
160+
let collect_thread_group ?leftover_limit group =
161+
let cnt = group.owner in
162+
let fin =
163+
let rec loop () = match group.state with
164+
| Running -> d_ "still running"; Condition.wait group.finished cnt.mut; loop ()
165+
| Finished fin -> fin
166+
in
167+
let< _ = Mutex.protect cnt.mut in loop ()
168+
in
169+
170+
let leftover_count =
171+
match leftover_limit with
172+
| Some (n, t) -> join_group ~leftover_thread_limit:n ~timeout:t group
173+
| None -> get_thread_count group
174+
in
175+
176+
fin, leftover_count
177+
178+
(** Check the return value of {!collect_thread_group} *)
179+
let check_thread_group_result ?leftover_limit fin leftover_count =
180+
let r = match fin with
181+
| Return (Ok x) -> Ok x
182+
| Return (Error e) -> Error [ExceptionRaised { main = true; exn_info = e }]
183+
| Uncaught e -> Error [ExceptionRaised { main = false; exn_info = e }]
184+
| Overflow l -> Error [ThreadLimitReached l]
185+
in
186+
187+
match leftover_limit with
188+
| Some (lim, (_ : Mtime.span)) when leftover_count > lim ->
189+
r |> add_err (ThreadsLeftOver { left_over = leftover_count ; limit = lim })
190+
| _ -> r
191+
192+
(** Combines {!spawn_thread_group}, {!collect_thread_group}, and {!check_thread_group_result}.*)
193+
let run_in_thread_group ?thread_limit ?leftover_limit cnt f x =
194+
let group = spawn_thread_group cnt f x ?thread_limit in
195+
let fin, leftover_count = collect_thread_group group ?leftover_limit in
196+
check_thread_group_result ?leftover_limit fin leftover_count
197+
198+
(* the rest is for error reporting, could do with less code/abstraction... *)
199+
200+
let string_of_thread_group_err = function
201+
| ThreadLimitReached n ->
202+
Printf.sprintf "Too many threads were used (> %d)" n
203+
| ThreadsLeftOver { left_over = n; limit } ->
204+
Printf.sprintf "Too many threads were left running (%d > %d)" n limit
205+
| ExceptionRaised { main; exn_info = Util.{exn; backtrace} } ->
206+
Printf.sprintf "%s thread raised an exception: %s\n%s"
207+
(if main then "The main" else "A created")
208+
(Printexc.to_string exn) (Printexc.raw_backtrace_to_string backtrace)
209+
|> String.trim
210+
211+
let string_of_thread_group_errs = function
212+
| [] -> "Unknown error in a thread group" (* this shouldn't happen *)
213+
| [err] -> "Error in a thread group: " ^ string_of_thread_group_err err
214+
| errs ->
215+
"Multiple errors in a thread group:\n" ^
216+
(errs |> List.mapi (fun i err -> [
217+
Printf.sprintf "+----------- %d -----------+" (i + 1);
218+
string_of_thread_group_err err])
219+
|> List.concat |> String.concat "\n")
220+
221+
exception ThreadGroupErrs of thread_group_err list
222+
223+
let _ =
224+
Printexc.register_printer
225+
(function ThreadGroupErrs errs -> Some (string_of_thread_group_errs errs) | _ -> None)
226+
227+
let thread_group_result_to_exn = function
228+
| Ok x -> x
229+
| Error errs -> raise (ThreadGroupErrs errs)
230+
231+
end
232+
233+
module CounterInstance () = struct
234+
let instance = Counter.create_counter ()
235+
236+
(** Like {!Stdlib.Thread}, but with counted threads *)
237+
module Thread = struct
238+
include Thread
239+
let create f x = Counter.spawn_thread instance Util.(ignore % f) x
240+
end
241+
end

0 commit comments

Comments
 (0)