@@ -34,9 +34,12 @@ declare interface ModelInputs extends tf.NamedTensorMap {
34
34
* to detect. Labels must be one of `toxicity` | `severe_toxicity` |
35
35
* `identity_attack` | `insult` | `threat` | `sexual_explicit` | `obscene`.
36
36
* Defaults to all labels.
37
+ * @param modelURL. URL to load model from. Defaults to tfhub.dev.
37
38
*/
38
- export async function load ( threshold : number , toxicityLabels : string [ ] ) {
39
+ export async function load ( threshold : number , toxicityLabels : string [ ] ,
40
+ modelURL : 'https://tfhub.dev/tensorflow/tfjs-model/toxicity/1/default/1' ) {
39
41
const model = new ToxicityClassifier ( threshold , toxicityLabels ) ;
42
+ model . setModelURL ( modelURL ) ;
40
43
await model . load ( ) ;
41
44
return model ;
42
45
}
@@ -47,16 +50,19 @@ export class ToxicityClassifier {
47
50
private labels : string [ ] ;
48
51
private threshold : number ;
49
52
private toxicityLabels : string [ ] ;
53
+ private modelURL : string ;
50
54
51
55
constructor ( threshold = 0.85 , toxicityLabels : string [ ] = [ ] ) {
52
56
this . threshold = threshold ;
53
57
this . toxicityLabels = toxicityLabels ;
54
58
}
55
59
60
+ setModelURL ( url : string ) {
61
+ this . modelURL = url
62
+ }
63
+
56
64
async loadModel ( ) {
57
- return tfconv . loadGraphModel (
58
- 'https://tfhub.dev/tensorflow/tfjs-model/toxicity/1/default/1' ,
59
- { fromTFHub : true } ) ;
65
+ return tfconv . loadGraphModel ( this . modelURL , { fromTFHub : true } ) ;
60
66
}
61
67
62
68
async loadTokenizer ( ) {
0 commit comments