Add Mistral AI Chat Completion support to Inference Plugin #128538
Conversation
Looking good. I left a few suggestions.
Could you update the description of the PR so that the example requests are formatted? Let's also wrap them in code blocks using the three backticks. Thanks for making them collapsible sections, though!
builder.startObject(STREAM_OPTIONS_FIELD);
builder.field(INCLUDE_USAGE_FIELD, true);
builder.endObject();
fillStreamOptionsFields(builder);
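For context, a minimal sketch of what the extracted helper might look like, assuming the field constants and XContentBuilder calls shown in the diff above (the method body is inferred from the replaced lines, not taken from the PR):

```java
import java.io.IOException;
import org.elasticsearch.xcontent.XContentBuilder;

// Hypothetical sketch: the three inlined builder calls from the diff above,
// extracted into a single helper.
private void fillStreamOptionsFields(XContentBuilder builder) throws IOException {
    builder.startObject(STREAM_OPTIONS_FIELD);  // "stream_options": {
    builder.field(INCLUDE_USAGE_FIELD, true);   //     "include_usage": true
    builder.endObject();                        // }
}
```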
Just an FYI we have some inflight changes that'll affect how we do this: #128592
Thank you for the heads up! I inspected the changes in the attached PR. They don't affect my changes. We're good.
docs/changelog/128538.yaml (Outdated)
@@ -0,0 +1,5 @@
pr: 128538
summary: "[ML] Add Mistral Chat Completion support to the Inference Plugin"
summary: "[ML] Add Mistral Chat Completion support to the Inference Plugin" | |
summary: "Added Mistral Chat Completion support to the Inference Plugin" |
Fixed.
action.execute(inputs, timeout, listener);
} else {
    listener.onFailure(createInvalidModelException(model));
switch (model) {
Just a reminder that I believe this will cause failure when backporting to 8.19 because of the JDK version difference.
Yes, I forgot about the JDK version. Fixed now.
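For readers unfamiliar with the backport concern: pattern matching for switch was only finalized in JDK 21, while the 8.19 branch builds against an older JDK. A minimal, self-contained illustration with hypothetical model types:

```java
// Hypothetical model types, for illustration only.
interface Model {}
record ChatModel(String id) implements Model {}
record EmbeddingsModel(String id) implements Model {}

class Dispatch {
    // Pattern matching for switch was finalized in JDK 21, so this form
    // fails to compile on the older JDK used by the 8.19 branch.
    static String dispatchNew(Model model) {
        return switch (model) {
            case ChatModel c -> "chat:" + c.id();
            case EmbeddingsModel e -> "embeddings:" + e.id();
            default -> throw new IllegalArgumentException("invalid model: " + model);
        };
    }

    // Backport-friendly equivalent: pattern matching for instanceof (JDK 16+).
    static String dispatchOld(Model model) {
        if (model instanceof ChatModel c) {
            return "chat:" + c.id();
        } else if (model instanceof EmbeddingsModel e) {
            return "embeddings:" + e.id();
        } else {
            throw new IllegalArgumentException("invalid model: " + model);
        }
    }
}
```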
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
    .createParser(XContentParserConfiguration.EMPTY, response.body())
) {
    var responseMap = jsonParser.map();
I think we're eventually going to refactor the error parsing logic to not parse at all (across all the services). How about we convert the bytes to a UTF-8 string and return that?
Done. Also renamed it to MistralErrorResponse because it isn't really an entity. Do let me know if you prefer MistralErrorResponseEntity.
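A minimal sketch of the suggested approach, assuming the response body is available as a byte array (the helper name is hypothetical):

```java
import java.nio.charset.StandardCharsets;

// Instead of parsing the error body as JSON, return it verbatim as a UTF-8 string.
public static String errorBodyAsString(byte[] responseBody) {
    return new String(responseBody, StandardCharsets.UTF_8);
}
```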
@@ -86,11 +88,21 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr
    throw new RetryException(false, buildError(AUTHENTICATION, request, result));
} else if (statusCode >= 300 && statusCode < 400) {
    throw new RetryException(false, buildError(REDIRECTION, request, result));
} else if (statusCode == 422) {
I think it might be confusing looking at this when the openai docs don't mention these error codes. How about we do this:
Let's push the successful status code check up into the base class (let's do that in a separate PR):

if (result.isSuccessfulResponse()) {
    return;
}

I briefly looked at all the response handlers and they're duplicating those lines anyway. If the call in the base class succeeds, we can skip calling checkForFailureStatusCode for the child classes.
I think we can then override this method in the child class and add these error code checks after calling super.checkForFailureStatusCode().
That approach will lead to duplication, because this check would have to be present in both the streaming and the non-streaming handler classes. The streaming handler has this hierarchy:

MistralUnifiedChatCompletionResponseHandler
  extends OpenAiUnifiedChatCompletionResponseHandler
  extends OpenAiChatCompletionResponseHandler
  extends OpenAiResponseHandler

The non-streaming handler has this hierarchy:

MistralCompletionResponseHandler
  extends OpenAiChatCompletionResponseHandler
  extends OpenAiResponseHandler

The closest common ancestor is OpenAiChatCompletionResponseHandler, but the streaming handler also needs to take in logic from OpenAiUnifiedChatCompletionResponseHandler. Do we really want to introduce such duplication? It would be literally one-to-one, because the errors are the same. Not very good practice. Please correct me if I'm wrong.

> I think we can then override this method in the child class and add these error code checks after calling super.checkForFailureStatusCode()

Also, I'd make the call to super in the default branch of the check in the child class, because the superclass always throws an exception (defaulting to an undefined error), while the child class's check can default to calling the super method.
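A minimal sketch of the pattern being discussed, with simplified, hypothetical signatures: the child class checks its service-specific codes first and defaults to calling super, while the base class owns the success check:

```java
// Simplified, hypothetical signatures for illustration only.
class BaseResponseHandler {
    protected void checkForFailureStatusCode(int statusCode) {
        if (statusCode >= 200 && statusCode < 300) {
            return; // successful status check pushed up into the base class
        }
        throw new IllegalStateException("received a failure status code: " + statusCode);
    }
}

class ChildResponseHandler extends BaseResponseHandler {
    @Override
    protected void checkForFailureStatusCode(int statusCode) {
        if (statusCode == 422) { // service-specific error code handled first
            throw new IllegalStateException("unprocessable entity");
        }
        // Default branch: fall back to the superclass checks.
        super.checkForFailureStatusCode(statusCode);
    }
}
```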
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponseEntity;

/**
 * Handles non-streaming chat completion responses for Mistral models, extending the OpenAI chat completion response handler.
Suggested change:
- * Handles non-streaming chat completion responses for Mistral models, extending the OpenAI chat completion response handler.
+ * Handles non-streaming completion responses for Mistral models, extending the OpenAI chat completion response handler.
Fixed. Thanks.
@@ -58,6 +58,20 @@ public MistralEmbeddingsModel(MistralEmbeddingsModel model, MistralEmbeddingsSer
    setPropertiesFromServiceSettings(serviceSettings);
}

protected void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) {
Does this need to be protected? Can we make it private? I think we could potentially leak the this context if this class is extended.
I don't see a reason for it to be protected. Changed it to private. Same for completions.
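For context on the this-leak concern: if a constructor calls an overridable method, a subclass override can run before the subclass's fields are initialized. A minimal, self-contained illustration (hypothetical classes):

```java
class Base {
    Base() {
        init(); // overridable call from a constructor leaks a partially constructed `this`
    }
    protected void init() {}
}

class Derived extends Base {
    private final String name = "mistral"; // assigned only after Base's constructor returns

    @Override
    protected void init() {
        // Runs while Base's constructor is still executing:
        System.out.println("name = " + name); // prints "name = null"
    }
}
```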
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.openai;
Let's move this class to the mistral package.
Good catch. Moved it to the appropriate package.
@@ -0,0 +1,51 @@
/*
@prwhelan any suggestions on how to intentionally encounter a midstream error while testing?
 * Handles streaming chat completion responses and error parsing for Mistral inference endpoints.
 * Adapts the OpenAI handler to support Mistral's simpler error schema with fields like "message" and "http_status_code".
 */
public class MistralUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
If the midstream errors are in the same format as openai, how about we refactor the OpenAiUnifiedChatCompletionResponseHandler so that we can replace some of the strings that reference openai specifically? I think it's just the name of the parser (lines 115 to 120 in 2830768):

    "open_ai_error",
    true,
    args -> Optional.ofNullable((OpenAiErrorResponse) args[0])
);
private static final ConstructingObjectParser<OpenAiErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
    "open_ai_error",

Maybe we could extract those classes and rename them to be more generic?
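A hedged sketch of what the extraction might look like, reusing the ConstructingObjectParser from the snippet above. The generic parser name and the shared ErrorResponse type are assumptions for illustration, not the PR's final shape:

```java
import java.util.Optional;
import org.elasticsearch.xcontent.ConstructingObjectParser;

// Hypothetical: a shared parser whose registered name no longer references openai.
// `ErrorResponse` stands in for an extracted, provider-agnostic error type.
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
    "chat_completion_error", // generic name instead of "open_ai_error"
    true,
    args -> Optional.ofNullable((ErrorResponse) args[0])
);
```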
Done.
But we're not positive that midstream errors are in the same format as OpenAI's. It's just assumed from the fact that Mistral uses an OpenAI-style API.
…mistral-chat-completion-integration
# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
Hi @jonathan-buttner. I also updated the section related to testing. Three backticks didn't work for me when I was creating the initial PR description, but when I applied them afterwards they worked perfectly. Not sure why, but I'll remember that for future PRs.
New error format:

Create Completion Endpoint
Not Found:
Unauthorized:
Invalid Model:

Perform Non-Streaming Completion
Not Found:

Perform Streaming Completion
Not Found:

Create Chat Completion Endpoint
Not Found:
Unauthorized:
Invalid Model:

Perform Streaming Chat Completion
Not Found:
Negative Max Tokens:
Invalid Model:
This change extends the existing Mistral AI provider integration, allowing completion (both streaming and non-streaming) and chat_completion (streaming only) to be executed as part of the inference API.
Changes were tested against the following models:
Notes:
Examples of requests/responses from local testing:
Create Completion Endpoint
Success:
Unauthorized:
Not Found:
Invalid Model:
Perform Non-Streaming Completion
Success:
Not Found:
Perform Streaming Completion
Success:
Not Found:
Create Chat Completion Endpoint
Success:
Unauthorized:
Not Found:
Invalid Model:
Perform Streaming Chat Completion
Success:
Invalid Model:
Not Found:
gradle check