Skip to content

♻️ refactor(google-genai): migrate from @google/generative-ai to @google/genai #7954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@
"@codesandbox/sandpack-react": "^2.20.0",
"@cyntler/react-doc-viewer": "^1.17.0",
"@electric-sql/pglite": "0.2.17",
"@google-cloud/vertexai": "^1.10.0",
"@google/generative-ai": "^0.24.1",
"@google/genai": "^1.0.1",
"@huggingface/inference": "^2.8.1",
"@icons-pack/react-simple-icons": "9.6.0",
"@khmyznikov/pwa-install": "0.3.9",
Expand Down
24 changes: 12 additions & 12 deletions src/libs/model-runtime/google/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// @vitest-environment edge-runtime
import { FunctionDeclarationsTool } from '@google/generative-ai';
import { Tool as FunctionDeclarationsTool, GoogleGenAI } from '@google/genai';
import OpenAI from 'openai';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

Expand All @@ -22,7 +22,7 @@
instance = new LobeGoogleAI({ apiKey: 'test' });

// 使用 vi.spyOn 来模拟 chat.completions.create 方法
vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockResolvedValue(new ReadableStream()),
} as any);
});
Expand Down Expand Up @@ -60,7 +60,7 @@
controller.close();
},
});
vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockResolvedValueOnce(mockStream),
} as any);

Expand Down Expand Up @@ -208,7 +208,7 @@
},
});

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockResolvedValueOnce(mockStream),
} as any);
});
Expand All @@ -224,7 +224,7 @@
controller.close();
},
});
vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockResolvedValueOnce(mockStream),
} as any);
const debugStreamSpy = vi
Expand All @@ -250,7 +250,7 @@

const apiError = new Error(message);

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

Expand All @@ -271,7 +271,7 @@

const apiError = new Error(message);

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

Expand All @@ -292,7 +292,7 @@

const apiError = new Error(message);

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

Expand Down Expand Up @@ -326,7 +326,7 @@

const apiError = new Error(message);

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

Expand Down Expand Up @@ -354,7 +354,7 @@
const apiError = new Error('Error message');

// 使用 vi.spyOn 来模拟 chat.completions.create 方法
vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

Expand Down Expand Up @@ -392,7 +392,7 @@
};
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

Expand All @@ -418,7 +418,7 @@
// Arrange
const genericError = new Error('Generic Error');

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
vi.spyOn(instance['client'].models, 'generateContentStream').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(genericError),
} as any);

Expand Down Expand Up @@ -590,7 +590,7 @@
const googleTools = instance['buildGoogleTools'](tools);

expect(googleTools).toHaveLength(1);
expect((googleTools![0] as FunctionDeclarationsTool).functionDeclarations![0]).toEqual({

Check failure on line 593 in src/libs/model-runtime/google/index.test.ts

View workflow job for this annotation

GitHub Actions / test

src/libs/model-runtime/google/index.test.ts > LobeGoogleAI > private method > buildGoogleTools > should correctly convert ChatCompletionTool to GoogleFunctionCallTool

AssertionError: expected { description: 'A test tool', …(2) } to deeply equal { name: 'testTool', …(2) } - Expected + Received @@ -1,9 +1,10 @@ { "description": "A test tool", "name": "testTool", "parameters": { + "description": undefined, "properties": { "param1": { "type": "string", }, "param2": { @@ -11,8 +12,8 @@ }, }, "required": [ "param1", ], - "type": "object", + "type": "OBJECT", }, } ❯ src/libs/model-runtime/google/index.test.ts:593:88
name: 'testTool',
description: 'A test tool',
parameters: {
Expand Down
96 changes: 44 additions & 52 deletions src/libs/model-runtime/google/index.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import type { VertexAI } from '@google-cloud/vertexai';
import {
Content,
FunctionCallPart,
FunctionDeclaration,
Tool as GoogleFunctionCallTool,
GoogleGenerativeAI,
GoogleSearchRetrievalTool,
GoogleGenAI,
Part,
SchemaType,
} from '@google/generative-ai';
Type as SchemaType,
} from '@google/genai';

import type { ChatModelCard } from '@/types/llm';
import { imageUrlToBase64 } from '@/utils/imageToBase64';
Expand Down Expand Up @@ -41,7 +38,7 @@ const modelsWithModalities = new Set([
'gemini-2.0-flash-preview-image-generation',
]);

const modelsDisableInstuction = new Set([
const modelsDisableInstruction = new Set([
'gemini-2.0-flash-exp',
'gemini-2.0-flash-exp-image-generation',
'gemini-2.0-flash-preview-image-generation',
Expand Down Expand Up @@ -77,7 +74,7 @@ const DEFAULT_BASE_URL = 'https://generativelanguage.googleapis.com';
interface LobeGoogleAIParams {
apiKey?: string;
baseURL?: string;
client?: GoogleGenerativeAI | VertexAI;
client?: GoogleGenAI;
id?: string;
isVertexAi?: boolean;
}
Expand All @@ -87,7 +84,7 @@ interface GoogleAIThinkingConfig {
}

export class LobeGoogleAI implements LobeRuntimeAI {
private client: GoogleGenerativeAI;
private client: GoogleGenAI;
private isVertexAi: boolean;
baseURL?: string;
apiKey?: string;
Expand All @@ -96,9 +93,10 @@ export class LobeGoogleAI implements LobeRuntimeAI {
constructor({ apiKey, baseURL, client, isVertexAi, id }: LobeGoogleAIParams = {}) {
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

this.client = new GoogleGenerativeAI(apiKey);
this.apiKey = apiKey;
this.client = client ? (client as GoogleGenerativeAI) : new GoogleGenerativeAI(apiKey);
this.client = client
? (client as GoogleGenAI)
: new GoogleGenAI({ apiKey, vertexai: isVertexAi });
this.baseURL = client ? undefined : baseURL || DEFAULT_BASE_URL;
this.isVertexAi = isVertexAi || false;

Expand All @@ -117,50 +115,44 @@ export class LobeGoogleAI implements LobeRuntimeAI {
const contents = await this.buildGoogleMessages(payload.messages);

const inputStartAt = Date.now();
const geminiStreamResult = await this.client
.getGenerativeModel(
{
generationConfig: {
maxOutputTokens: payload.max_tokens,
// @ts-expect-error - Google SDK 0.24.0 doesn't have this property for now with
response_modalities: modelsWithModalities.has(model) ? ['Text', 'Image'] : undefined,
temperature: payload.temperature,
thinkingConfig,
topP: payload.top_p,
const geminiStreamResult = await this.client.models.generateContentStream({
config: {
maxOutputTokens: payload.max_tokens,
responseModalities: modelsWithModalities.has(model) ? ['Text', 'Image'] : undefined,

// avoid wide sensitive words
// refs: https://github.com/lobehub/lobe-chat/pull/1418
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: getThreshold(model),
},
model,
// avoid wide sensitive words
// refs: https://github.com/lobehub/lobe-chat/pull/1418
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: getThreshold(model),
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: getThreshold(model),
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: getThreshold(model),
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: getThreshold(model),
},
],
},
{ apiVersion: 'v1beta', baseUrl: this.baseURL },
)
.generateContentStream({
contents,
systemInstruction: modelsDisableInstuction.has(model)
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: getThreshold(model),
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: getThreshold(model),
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: getThreshold(model),
},
],
systemInstruction: modelsDisableInstruction.has(model)
? undefined
: (payload.system as string),
temperature: payload.temperature,
thinkingConfig,
tools: this.buildGoogleTools(payload.tools, payload),
});
topP: payload.top_p,
},
contents,
model,
});

const googleStream = convertIterableToStream(geminiStreamResult.stream);
const googleStream = convertIterableToStream(geminiStreamResult);
const [prod, useForDebug] = googleStream.tee();

const key = this.isVertexAi
Expand Down Expand Up @@ -291,7 +283,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
const content = message.content as string | UserMessageContentPart[];
if (!!message.tool_calls) {
return {
parts: message.tool_calls.map<FunctionCallPart>((tool) => ({
parts: message.tool_calls.map<Part>((tool) => ({
functionCall: {
args: safeParseJSON(tool.function.arguments)!,
name: tool.function.name,
Expand Down Expand Up @@ -384,7 +376,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
return; // 若历史消息中已有 function calling,则不再注入任何 Tools
}
if (payload?.enabledSearch) {
return [{ googleSearch: {} } as GoogleSearchRetrievalTool];
return [{ googleSearch: {} } as GoogleFunctionCallTool];
}

if (!tools || tools.length === 0) return;
Expand Down
30 changes: 17 additions & 13 deletions src/libs/model-runtime/utils/streams/google-ai.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { EnhancedGenerateContentResponse } from '@google/generative-ai';
import { GenerateContentResponse } from '@google/genai';
import { describe, expect, it, vi } from 'vitest';

import * as uuidModule from '@/utils/uuid';
Expand All @@ -11,10 +11,12 @@ describe('GoogleGenerativeAIStream', () => {

const mockGenerateContentResponse = (text: string, functionCalls?: any[]) =>
({
text: () => text,
functionCall: () => functionCalls?.[0],
functionCalls: () => functionCalls,
}) as EnhancedGenerateContentResponse;
text,
data: undefined,
executableCode: undefined,
codeExecutionResult: undefined,
functionCalls: functionCalls,
}) as GenerateContentResponse;

const mockGoogleStream = new ReadableStream({
start(controller) {
Expand Down Expand Up @@ -116,10 +118,12 @@ describe('GoogleGenerativeAIStream', () => {
};
const mockGenerateContentResponse = (text: string, functionCalls?: any[]) =>
({
text: () => text,
functionCall: () => functionCalls?.[0],
functionCalls: () => functionCalls,
}) as EnhancedGenerateContentResponse;
text: text,
data: undefined,
executableCode: undefined,
codeExecutionResult: undefined,
functionCalls: functionCalls,
}) as GenerateContentResponse;

const mockGoogleStream = new ReadableStream({
start(controller) {
Expand Down Expand Up @@ -209,7 +213,7 @@ describe('GoogleGenerativeAIStream', () => {
],
},
],
text: () => '234',
text: '234',
usageMetadata: {
promptTokenCount: 20,
totalTokenCount: 20,
Expand All @@ -218,7 +222,7 @@ describe('GoogleGenerativeAIStream', () => {
modelVersion: 'gemini-2.0-flash-exp-image-generation',
},
{
text: () => '567890\n',
text: '567890\n',
candidates: [
{
content: { parts: [{ text: '567890\n' }], role: 'model' },
Expand Down Expand Up @@ -299,7 +303,7 @@ describe('GoogleGenerativeAIStream', () => {
],
},
],
text: () => '234',
text: '234',
usageMetadata: {
promptTokenCount: 19,
candidatesTokenCount: 3,
Expand All @@ -310,7 +314,7 @@ describe('GoogleGenerativeAIStream', () => {
modelVersion: 'gemini-2.0-flash-exp-image-generation',
},
{
text: () => '567890\n',
text: '567890\n',
candidates: [
{
content: { parts: [{ text: '567890\n' }], role: 'model' },
Expand Down
12 changes: 6 additions & 6 deletions src/libs/model-runtime/utils/streams/google-ai.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { EnhancedGenerateContentResponse } from '@google/generative-ai';
import { GenerateContentResponse } from '@google/genai';

import { ModelTokensUsage } from '@/types/message';
import { GroundingSearch } from '@/types/search';
Expand All @@ -16,7 +16,7 @@ import {
} from './protocol';

const transformGoogleGenerativeAIStream = (
chunk: EnhancedGenerateContentResponse,
chunk: GenerateContentResponse,
context: StreamContext,
): StreamProtocolChunk | StreamProtocolChunk[] => {
// maybe need another structure to add support for multiple choices
Expand Down Expand Up @@ -50,7 +50,7 @@ const transformGoogleGenerativeAIStream = (
);
}

const functionCalls = chunk.functionCalls?.();
const functionCalls = chunk.functionCalls;

if (functionCalls) {
return [
Expand All @@ -73,9 +73,9 @@ const transformGoogleGenerativeAIStream = (
];
}

const text = chunk.text?.();
const text = chunk.text;

if (candidate) {
if (candidate && candidate.content) {
// 首先检查是否为 reasoning 内容 (thought: true)
if (Array.isArray(candidate.content.parts) && candidate.content.parts.length > 0) {
for (const part of candidate.content.parts) {
Expand Down Expand Up @@ -148,7 +148,7 @@ export interface GoogleAIStreamOptions {
}

export const GoogleGenerativeAIStream = (
rawStream: ReadableStream<EnhancedGenerateContentResponse>,
rawStream: ReadableStream<GenerateContentResponse>,
{ callbacks, inputStartAt }: GoogleAIStreamOptions = {},
) => {
const streamStack: StreamContext = { id: 'chat_' + nanoid() };
Expand Down
Loading
Loading