src/nodes/server/text-classification.ts (74 lines of code) (raw):
import type { VisualBlocksClassificationResult } from "@visualblocks/custom-node-types";
import { HfInference } from "@huggingface/inference";
import { NODE_SPEC } from "./text-classification-specs";
import { LitElement } from "lit";
import { compareObjects } from "../../utils";
/** Inputs wired into the text-classification node. */
interface Inputs {
  /** Text to classify. */
  text: string;
  /** Hugging Face model id used for the request. */
  modelid: string;
  /** Hugging Face API access token. */
  apikey: string;
}
/**
 * Outputs emitted by the text-classification node.
 *
 * `results` is `null` when there is no input text to classify.
 */
interface Outputs {
  results: VisualBlocksClassificationResult | null;
}
/**
 * Visual Blocks node that classifies text with a model hosted on the
 * Hugging Face Inference API.
 *
 * Every run ends by dispatching an `outputs` CustomEvent. Its `detail` is one of:
 * - `{ results }` on success,
 * - `{ results: null }` when there is no text to classify,
 * - `{ error: { title, message } }` when the request fails.
 */
class TextClassificationNode extends LitElement {
  /** Inputs of the last successful run; used to skip redundant requests. */
  private cachedInputs?: Inputs;
  /** Output of the last successful run; re-emitted when inputs are unchanged. */
  private cachedOutput?: Outputs;
  /** Inference client. It is rebuilt only when the API key changes. */
  private hf: HfInference;
  /** API key that `hf` was built with (`undefined` means an anonymous client). */
  private hfApiKey?: string;

  constructor() {
    super();
    this.hf = new HfInference();
  }

  render() {
    // This node doesn't have a preview UI.
  }

  /**
   * Classifies `inputs.text` and dispatches the result as an `outputs` event.
   *
   * @param inputs - text, model id and API key provided by the graph.
   */
  async runWithInputs(inputs: Inputs) {
    const { text, apikey, modelid } = inputs;
    const model = modelid?.trim();

    this.syncClient(apikey);

    if (!text) {
      // Nothing to classify: clear any downstream result.
      this.emitOutputs({ results: null });
      return;
    }

    // The inputs match the last successful run, so reuse its result without a network call.
    if (this.cachedOutput && compareObjects(this.cachedInputs, inputs)) {
      this.emitOutputs(this.cachedOutput);
      return;
    }

    try {
      const textClassRes = await this.hf.textClassification({
        model,
        inputs: text,
      });
      if (!textClassRes) {
        throw new Error("Invalid response");
      }
      // Remap the label/score pairs to the Visual Blocks classification result shape.
      const output: Outputs = {
        results: {
          classes: textClassRes.map(({ label, score }) => ({
            className: label,
            probability: score,
          })),
        },
      };
      this.cachedOutput = output;
      // Store a copy so that later mutation of the caller's object cannot
      // corrupt the cache key.
      this.cachedInputs = { ...inputs };
      this.emitOutputs(output);
    } catch (error: unknown) {
      this.emitOutputs({
        error: {
          title: "Error",
          message: error instanceof Error ? error.message : String(error),
        },
      });
    }
  }

  /**
   * Makes `hf` use `apikey`, or an anonymous client when no key is given.
   * The client is rebuilt only when the key differs from the current one, so
   * a cleared key no longer leaves the previous credentials in use.
   */
  private syncClient(apikey: string | undefined) {
    const key = apikey || undefined;
    if (key === this.hfApiKey) {
      return;
    }
    this.hf = new HfInference(key);
    this.hfApiKey = key;
  }

  /** Dispatches `detail` to the graph as an `outputs` CustomEvent. */
  private emitOutputs(detail: unknown) {
    this.dispatchEvent(new CustomEvent("outputs", { detail }));
  }
}
/** Registration entry for this node: its spec paired with its implementation class. */
const textClassificationNodeEntry = {
  nodeSpec: NODE_SPEC,
  nodeImpl: TextClassificationNode,
};
export default textClassificationNodeEntry;