Skip to content

Commit 6e1d658

Browse files
committed
Support custom model URL
1 parent 35e5986 commit 6e1d658

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

toxicity/src/index.ts

+10-4
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ declare interface ModelInputs extends tf.NamedTensorMap {
3434
* to detect. Labels must be one of `toxicity` | `severe_toxicity` |
3535
* `identity_attack` | `insult` | `threat` | `sexual_explicit` | `obscene`.
3636
* Defaults to all labels.
37+
* @param modelURL. URL to load model from. Defaults to tfhub.dev.
3738
*/
38-
export async function load(threshold: number, toxicityLabels: string[]) {
39+
export async function load(threshold: number, toxicityLabels: string[],
40+
modelURL: 'https://tfhub.dev/tensorflow/tfjs-model/toxicity/1/default/1') {
3941
const model = new ToxicityClassifier(threshold, toxicityLabels);
42+
model.setModelURL(modelURL);
4043
await model.load();
4144
return model;
4245
}
@@ -47,16 +50,19 @@ export class ToxicityClassifier {
4750
private labels: string[];
4851
private threshold: number;
4952
private toxicityLabels: string[];
53+
private modelURL: string;
5054

5155
constructor(threshold = 0.85, toxicityLabels: string[] = []) {
5256
this.threshold = threshold;
5357
this.toxicityLabels = toxicityLabels;
5458
}
5559

60+
setModelURL(url: string) {
61+
this.modelURL = url
62+
}
63+
5664
async loadModel() {
57-
return tfconv.loadGraphModel(
58-
'https://tfhub.dev/tensorflow/tfjs-model/toxicity/1/default/1',
59-
{fromTFHub: true});
65+
return tfconv.loadGraphModel(this.modelURL, {fromTFHub: true});
6066
}
6167

6268
async loadTokenizer() {

0 commit comments

Comments
 (0)