in x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java [38:384]
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) implements InferenceServiceResults {
// Name constant for this result type; its use is not shown in this section of the file.
public static final String NAME = "chat_completion_chunk";
// Field names used when the nested records below render themselves as XContent.
public static final String MODEL_FIELD = "model";
public static final String OBJECT_FIELD = "object";
public static final String USAGE_FIELD = "usage";
public static final String INDEX_FIELD = "index";
public static final String ID_FIELD = "id";
public static final String FUNCTION_NAME_FIELD = "name";
public static final String FUNCTION_ARGUMENTS_FIELD = "arguments";
public static final String FUNCTION_FIELD = "function";
public static final String CHOICES_FIELD = "choices";
public static final String DELTA_FIELD = "delta";
public static final String CONTENT_FIELD = "content";
public static final String REFUSAL_FIELD = "refusal";
public static final String ROLE_FIELD = "role";
// NOTE(review): this is the only private field-name constant in the group; confirm whether that is intentional.
private static final String TOOL_CALLS_FIELD = "tool_calls";
public static final String FINISH_REASON_FIELD = "finish_reason";
public static final String COMPLETION_TOKENS_FIELD = "completion_tokens";
public static final String TOTAL_TOKENS_FIELD = "total_tokens";
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
public static final String TYPE_FIELD = "type";
/**
 * OpenAI Spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible.
 * So we will insert a buffer in between the upstream data and the downstream client so that we only send one chunk at a time.
 * <p>
 * When an upstream {@link Results} carries several chunks, the first one is forwarded to the downstream subscriber
 * immediately and the remainder is buffered. Each subsequent downstream {@code request} is served with a single
 * buffered chunk; upstream is only asked for more data once the buffer is empty. Completion is deferred until the
 * buffer has been drained.
 * <p>
 * The buffer and the completion flag are created per subscription, so independent subscribers never observe each
 * other's buffered chunks or completion state.
 *
 * @param publisher the upstream publisher whose results are re-emitted one chunk at a time
 */
public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) {
    this.publisher = downstream -> {
        // per-subscription state: must not be shared between subscribers
        Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
        AtomicBoolean onComplete = new AtomicBoolean();
        publisher.subscribe(new Flow.Subscriber<>() {
            // kept so that onNext can re-request when upstream delivers an empty result
            private volatile Flow.Subscription upstream;

            @Override
            public void onSubscribe(Flow.Subscription subscription) {
                upstream = subscription;
                downstream.onSubscribe(new Flow.Subscription() {
                    @Override
                    public void request(long n) {
                        // serve from the buffer first; one chunk per request regardless of n
                        var nextItem = buffer.poll();
                        if (nextItem != null) {
                            downstream.onNext(new Results(DequeUtils.of(nextItem)));
                        } else if (onComplete.get()) {
                            // upstream finished earlier and the buffer is now drained
                            downstream.onComplete();
                        } else {
                            subscription.request(n);
                        }
                    }

                    @Override
                    public void cancel() {
                        subscription.cancel();
                    }
                });
            }

            @Override
            public void onNext(Results item) {
                var chunks = item.chunks();
                var firstItem = chunks.poll();
                if (firstItem == null) {
                    // Upstream delivered no chunks: there is nothing to forward and the downstream demand is still
                    // unsatisfied, so ask upstream for another result instead of emitting a null chunk.
                    upstream.request(1);
                    return;
                }
                chunks.forEach(buffer::offer);
                downstream.onNext(new Results(DequeUtils.of(firstItem)));
            }

            @Override
            public void onError(Throwable throwable) {
                downstream.onError(throwable);
            }

            @Override
            public void onComplete() {
                // only complete if the buffer is empty, so that the client has a chance to drain the buffer
                if (onComplete.compareAndSet(false, true)) {
                    if (buffer.isEmpty()) {
                        downstream.onComplete();
                    }
                }
            }
        });
    };
}
/**
 * Marks this result as a streaming result.
 *
 * @return always {@code true}
 */
@Override
public boolean isStreaming() {
return true;
}
/**
 * Not supported on the streaming wrapper itself; the individual {@link Results} emitted by the publisher
 * provide their own {@code toXContentChunked}.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
throw new UnsupportedOperationException("Not implemented");
}
/**
 * A batch of {@link ChatCompletionChunk}s emitted by the streaming publisher.
 * <p>
 * Equality and hashing are delegated to {@code dequeEquals} / {@code dequeHashCode} rather than to the
 * deque's own {@code equals} / {@code hashCode}.
 *
 * @param chunks the chunks carried by this result, in emission order
 */
public record Results(Deque<ChatCompletionChunk> chunks) implements InferenceServiceResults.Result {
    /** Writeable name returned by {@link #getWriteableName()}; constant, so declared {@code final}. */
    public static final String NAME = "streaming_unified_chat_completion_results";

    /**
     * Reads the deque of chunks from the stream.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Results(StreamInput in) throws IOException {
        this(readDeque(in, ChatCompletionChunk::new));
    }

    /** Renders every chunk, one after another, in deque order. */
    @Override
    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
        return Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params));
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** Writes all chunks as a collection; the counterpart of {@link #Results(StreamInput)}. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(chunks, StreamOutput::writeWriteable);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Results results = (Results) o;
        return dequeEquals(chunks, results.chunks());
    }

    @Override
    public int hashCode() {
        return dequeHashCode(chunks);
    }
}
public record ChatCompletionChunk(String id, List<Choice> choices, String model, String object, ChatCompletionChunk.Usage usage)
implements
ChunkedToXContent,
Writeable {
/**
 * Deserializes a chunk from the stream. Values are read in the same order in which {@code writeTo} writes them:
 * id, optional list of choices, model, object, optional usage.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
private ChatCompletionChunk(StreamInput in) throws IOException {
this(
in.readString(),
in.readOptionalCollectionAsList(Choice::new),
in.readString(),
in.readString(),
in.readOptional(Usage::new)
);
}
/**
 * Renders this chunk as a single object: the id, the "choices" array (left out when {@code choices} is null),
 * the model and object fields, and the "usage" object (left out when {@code usage} is null).
 */
@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
    final Iterator<? extends ToXContent> opening = ChunkedToXContentHelper.startObject();
    final Iterator<? extends ToXContent> idField = chunk((builder, unused) -> builder.field(ID_FIELD, id));

    // the choices array is only present when there are choices to report
    final Iterator<? extends ToXContent> choicesArray;
    if (choices == null) {
        choicesArray = Collections.emptyIterator();
    } else {
        choicesArray = ChunkedToXContentHelper.array(CHOICES_FIELD, choices.iterator(), params);
    }

    final Iterator<? extends ToXContent> modelAndObject = chunk(
        (builder, unused) -> builder.field(MODEL_FIELD, model).field(OBJECT_FIELD, object)
    );

    // the usage object is only present when usage was supplied
    final Iterator<? extends ToXContent> usageObject;
    if (usage == null) {
        usageObject = Collections.emptyIterator();
    } else {
        usageObject = chunk(
            (builder, unused) -> builder.startObject(USAGE_FIELD)
                .field(COMPLETION_TOKENS_FIELD, usage.completionTokens())
                .field(PROMPT_TOKENS_FIELD, usage.promptTokens())
                .field(TOTAL_TOKENS_FIELD, usage.totalTokens())
                .endObject()
        );
    }

    return Iterators.concat(opening, idField, choicesArray, modelAndObject, usageObject, ChunkedToXContentHelper.endObject());
}
/**
 * Serializes this chunk. The write order (id, optional choices, model, object, optional usage) must stay in
 * step with the read order of the {@code StreamInput} constructor. Only {@code choices} and {@code usage}
 * are written through the optional variants.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(id);
out.writeOptionalCollection(choices);
out.writeString(model);
out.writeString(object);
out.writeOptionalWriteable(usage);
}
public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index)
implements
ChunkedToXContentObject,
Writeable {
/**
 * Deserializes a choice: the delta first, then the optional finish reason, then the index,
 * matching the order used by {@code writeTo}.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
private Choice(StreamInput in) throws IOException {
this(new Delta(in), in.readOptionalString(), in.readInt());
}
/*
 One element of the "choices" array:
 {
   delta: { ... };
   finish_reason: string;   (left out when the finish reason is null)
   index: number;
 }
*/
@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
    final Iterator<? extends ToXContent> opening = ChunkedToXContentHelper.startObject();
    final Iterator<? extends ToXContent> deltaObject = delta.toXContentChunked(params);
    final Iterator<? extends ToXContent> finishReasonField = optionalField(FINISH_REASON_FIELD, finishReason);
    final Iterator<? extends ToXContent> indexField = chunk((builder, unused) -> builder.field(INDEX_FIELD, index));
    final Iterator<? extends ToXContent> closing = ChunkedToXContentHelper.endObject();
    return Iterators.concat(opening, deltaObject, finishReasonField, indexField, closing);
}
/**
 * Serializes this choice as delta, optional finish reason, index; the {@code StreamInput} constructor
 * reads the values back in the same order.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeWriteable(delta);
out.writeOptionalString(finishReason);
out.writeInt(index);
}
public record Delta(String content, String refusal, String role, List<ToolCall> toolCalls) implements Writeable {
/**
 * Deserializes a delta. All four components are optional on the wire and are read in the order
 * content, refusal, role, tool calls, matching {@code writeTo}.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
private Delta(StreamInput in) throws IOException {
this(
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalCollectionAsList(ToolCall::new)
);
}
/*
 Rendered as a named "delta" object; every member is left out when it has no value:
 delta: {
   content?: string;
   refusal?: string;
   role?: string;
   tool_calls?: Array<{ ... }>;   (left out when the list is null or empty)
 };
*/
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
    final Iterator<? extends ToXContent> opening = ChunkedToXContentHelper.startObject(DELTA_FIELD);
    final Iterator<? extends ToXContent> contentField = optionalField(CONTENT_FIELD, content);
    final Iterator<? extends ToXContent> refusalField = optionalField(REFUSAL_FIELD, refusal);
    final Iterator<? extends ToXContent> roleField = optionalField(ROLE_FIELD, role);

    // the tool_calls array is only written when there is at least one tool call
    final Iterator<? extends ToXContent> toolCallsArray;
    if (toolCalls == null || toolCalls.isEmpty()) {
        toolCallsArray = Collections.emptyIterator();
    } else {
        toolCallsArray = Iterators.concat(
            ChunkedToXContentHelper.startArray(TOOL_CALLS_FIELD),
            Iterators.flatMap(toolCalls.iterator(), call -> call.toXContentChunked(params)),
            ChunkedToXContentHelper.endArray()
        );
    }

    return Iterators.concat(opening, contentField, refusalField, roleField, toolCallsArray, ChunkedToXContentHelper.endObject());
}
/**
 * Serializes this delta as four optional values in the order content, refusal, role, tool calls;
 * the {@code StreamInput} constructor reads them back in the same order.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(content);
out.writeOptionalString(refusal);
out.writeOptionalString(role);
out.writeOptionalCollection(toolCalls);
}
/**
 * A single tool call fragment inside a {@link Delta}.
 * <p>
 * Only {@code index} is mandatory; {@code id}, {@code function} and {@code type} are optional both on the wire
 * and in the rendered output, where absent values are left out.
 *
 * @param index    position of the tool call
 * @param id       identifier of the tool call, may be null
 * @param function function name/arguments fragment, may be null
 * @param type     type of the tool call, may be null
 */
public record ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type)
    implements
        ChunkedToXContentObject,
        Writeable {

    /**
     * Deserializes a tool call in the order index, optional id, optional function, optional type,
     * matching {@link #writeTo(StreamOutput)}.
     */
    private ToolCall(StreamInput in) throws IOException {
        this(
            in.readInt(),
            in.readOptionalString(),
            in.readOptional(ChatCompletionChunk.Choice.Delta.ToolCall.Function::new),
            in.readOptionalString()
        );
    }

    /*
    index: number;
    id?: string;
    function?: {
      arguments?: string;
      name?: string;
    };
    type?: 'function';
    */
    @Override
    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
        var content = Iterators.concat(
            ChunkedToXContentHelper.startObject(),
            chunk((b, p) -> b.field(INDEX_FIELD, index)),
            optionalField(ID_FIELD, id)
        );

        if (function != null) {
            content = Iterators.concat(
                content,
                ChunkedToXContentHelper.startObject(FUNCTION_FIELD),
                optionalField(FUNCTION_ARGUMENTS_FIELD, function.arguments()),
                optionalField(FUNCTION_NAME_FIELD, function.name()),
                ChunkedToXContentHelper.endObject()
            );
        }

        // type is optional like id and function, so it is left out when null, consistent with the other optional fields
        content = Iterators.concat(content, optionalField(TYPE_FIELD, type), ChunkedToXContentHelper.endObject());
        return content;
    }

    /** Serializes this tool call; the write order must match the {@code StreamInput} constructor. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeInt(index);
        out.writeOptionalString(id);
        out.writeOptionalWriteable(function);
        out.writeOptionalString(type);
    }

    /**
     * Function fragment of a tool call; both components are optional.
     *
     * @param arguments the arguments fragment, may be null
     * @param name      the function name, may be null
     */
    public record Function(String arguments, String name) implements Writeable {
        /** Reads the optional arguments followed by the optional name. */
        private Function(StreamInput in) throws IOException {
            this(in.readOptionalString(), in.readOptionalString());
        }

        /** Writes the optional arguments followed by the optional name. */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalString(arguments);
            out.writeOptionalString(name);
        }
    }
}
}
}
/**
 * Token counts attached to a chunk; rendered by the enclosing chunk under the "usage" object as
 * completion_tokens, prompt_tokens and total_tokens.
 *
 * @param completionTokens value written as completion_tokens
 * @param promptTokens     value written as prompt_tokens
 * @param totalTokens      value written as total_tokens
 */
public record Usage(int completionTokens, int promptTokens, int totalTokens) implements Writeable {
// Reads the three ints in the same order writeTo writes them: completion, prompt, total.
private Usage(StreamInput in) throws IOException {
this(in.readInt(), in.readInt(), in.readInt());
}
// Writes the three ints in the order completion, prompt, total; must stay in step with the constructor above.
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(completionTokens);
out.writeInt(promptTokens);
out.writeInt(totalTokens);
}
}
/**
 * Produces the XContent for a string field that is only written when it has a value.
 *
 * @param name  the field name
 * @param value the field value, may be null
 * @return an empty iterator when {@code value} is null, otherwise a single chunk writing {@code name: value}
 */
private static Iterator<ToXContent> optionalField(String name, String value) {
    return value == null
        ? Collections.emptyIterator()
        : ChunkedToXContentHelper.chunk((builder, unused) -> builder.field(name, value));
}
}
}