Skip to content

Commit 032cbb4

Browse files
hahaQWQhahaqwq
and
hahaqwq
authored
feat: embeddings model support (#86)
* feat: embeddings model support * fix: cli embed * fix: server embedding fix * fix: config save * fix: type error * style: embeddings * feat: embedding_models schema * feat: update message embedding functions to include chatId and improve embedding table retrieval * feat: add embedding_models table and update embedding functions to include chatId for improved retrieval * feat: enhance search functionality with TypeScript types and improved result handling - Added TypeScript interfaces for Command and CommandMetadata to improve type safety. - Updated the search function to return a structured SearchResponse. - Introduced a new types file for SearchResult and SearchResponse to standardize search-related data structures. * fix: improve pagination and configuration handling in search and settings components - Updated pagination logic in search.vue to ensure total is a valid number before comparison with pageSize. - Ensured apiId in settings.vue is converted to a string before saving the configuration. - Removed unused imports in search.ts to clean up the codebase. - Adjusted parameters in findSimilarMessages function to include targetChatId for better search results. * feat: add migration for embedding tables and related triggers - Implemented migration for creating embedding_models table with necessary fields and triggers for automatic timestamp updates. - Added logic to create corresponding embedding tables for existing message tables, including necessary constraints and indexes. - Enhanced logging for better tracking of migration progress and errors. --------- Co-authored-by: hahaqwq <[email protected]>
1 parent fec4a4c commit 032cbb4

File tree

31 files changed

+985
-85
lines changed

31 files changed

+985
-85
lines changed

.cursorrules

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
- Code always in English.
2+
- Step by step thinking.
3+
- Always include TypeScript type annotations when generating TS code.
4+
- Follow good engineering practices: SOLID principles, functional programming, configurability, to ensure clear code structure.
5+
- Comments should focus on explaining "why" rather than "how."
6+
7+
「lets not answer this question too quickly, we can think in depth, and i am new to this area, please correct me wrong if there is any」
8+
9+
!!! If the user asks you in another language, explain in that language.

apps/cli/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
},
2323
"devDependencies": {
2424
"@types/jsdom": "^21.1.7",
25-
"@types/node": "^22.13.9",
25+
"@types/node": "^22.13.10",
2626
"tsx": "^4.19.3"
2727
}
2828
}

apps/cli/src/commands/embed.ts

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import * as input from '@inquirer/prompts'
22
import { useLogger } from '@tg-search/common'
33
import { EmbeddingService } from '@tg-search/core'
4-
import { findMessagesByChatId, updateMessageEmbeddings } from '@tg-search/db'
4+
import { findMessageMissingEmbed, updateMessageEmbeddings, useEmbeddingTable } from '@tg-search/db'
55

66
import { TelegramCommand } from '../command'
77

88
const logger = useLogger()
9-
109
interface EmbedOptions {
1110
batchSize?: number
1211
chatId?: number
@@ -48,12 +47,13 @@ export class EmbedCommand extends TelegramCommand {
4847

4948
// Initialize embedding service
5049
const embedding = new EmbeddingService()
51-
50+
await useEmbeddingTable(Number(chatId), embedding.getEmbeddingConfig())
5251
try {
5352
// Get all messages for the chat
54-
const messages = await findMessagesByChatId(Number(chatId))
55-
const messagesToEmbed = messages.items.filter(m => !m.embedding && m.content)
56-
const totalMessages = messagesToEmbed.length
53+
const messages = await findMessageMissingEmbed(Number(chatId), embedding.getEmbeddingConfig())
54+
logger.debug(`共有 ${messages.length} 条消息需要处理`)
55+
56+
const totalMessages = messages.length
5757

5858
logger.log(`共有 ${totalMessages} 条消息需要处理`)
5959

@@ -62,25 +62,27 @@ export class EmbedCommand extends TelegramCommand {
6262
let failedEmbeddings = 0
6363

6464
// Split messages into batches
65-
for (let i = 0; i < messagesToEmbed.length; i += Number(batchSize)) {
66-
const batch = messagesToEmbed.slice(i, i + Number(batchSize))
65+
for (let i = 0; i < messages.length; i += Number(batchSize)) {
66+
const batch = messages.slice(i, i + Number(batchSize))
6767
logger.debug(`处理第 ${i + 1}${i + batch.length} 条消息`)
6868

6969
// Generate embeddings in parallel
7070
const contents = batch.map(m => m.content!)
7171
const embeddings = await embedding.generateEmbeddings(contents)
7272

7373
// Prepare updates
74-
const updates = batch.map((message, index) => ({
75-
id: message.id,
76-
embedding: embeddings[index],
77-
}))
74+
const updates = batch
75+
.filter(message => message.uuid)
76+
.map((message, index) => ({
77+
id: message.uuid!,
78+
embedding: embeddings[index],
79+
}))
7880

7981
try {
8082
// Update embeddings in batches with concurrency control
8183
for (let j = 0; j < updates.length; j += Number(concurrency)) {
8284
const concurrentBatch = updates.slice(j, j + Number(concurrency))
83-
await updateMessageEmbeddings(Number(chatId), concurrentBatch)
85+
await updateMessageEmbeddings(Number(chatId), concurrentBatch, embedding.getEmbeddingConfig())
8486
totalProcessed += concurrentBatch.length
8587
}
8688
}

apps/cli/src/commands/search.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import type { DatabaseMessageType, SearchOptions } from '@tg-search/db'
33
import * as input from '@inquirer/prompts'
44
import { useLogger } from '@tg-search/common'
55
import { EmbeddingService } from '@tg-search/core'
6-
import { findSimilarMessages, getAllFolders, getChatsInFolder } from '@tg-search/db'
6+
import { findSimilarMessages, getAllFolders, getChatsInFolder, useEmbeddingTable } from '@tg-search/db'
77

88
import { TelegramCommand } from '../command'
99

@@ -108,6 +108,7 @@ export class SearchCommand extends TelegramCommand {
108108

109109
// Initialize embedding service
110110
const embedding = new EmbeddingService()
111+
await useEmbeddingTable(Number(chatId), embedding.getEmbeddingConfig())
111112

112113
// Generate embedding for query
113114
logger.log('正在生成向量嵌入...')
@@ -124,7 +125,7 @@ export class SearchCommand extends TelegramCommand {
124125
limit: Number(limit),
125126
}
126127

127-
const results = await findSimilarMessages(queryEmbedding[0], options)
128+
const results = await findSimilarMessages(Number(chatId), queryEmbedding[0], embedding.getEmbeddingConfig(), options)
128129
logger.log(`找到 ${results.length} 条相关消息:\n`)
129130

130131
for (const message of results) {

apps/frontend/src/apis/commands/useSearch.ts

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
11
import type { SearchRequest } from '@tg-search/server'
2+
import type { SearchResponse, SearchResult } from '../../types/search'
23

34
import { computed, ref } from 'vue'
45

56
import { useCommandHandler } from '../../composables/useCommands'
67

8+
interface CommandMetadata {
9+
command: string
10+
duration?: number
11+
total?: number
12+
results?: SearchResult[]
13+
}
14+
15+
interface Command {
16+
id: string
17+
type: string
18+
status: 'pending' | 'completed' | 'failed'
19+
progress: number
20+
message: string
21+
error?: Error
22+
metadata?: CommandMetadata
23+
}
24+
725
/**
826
* Search composable for managing search state and functionality
927
*/
@@ -33,9 +51,8 @@ export function useSearch() {
3351

3452
// Computed results from current command
3553
const results = computed(() => {
36-
if (!currentCommand.value?.metadata?.results)
37-
return []
38-
return currentCommand.value.metadata.results
54+
const commandResults = currentCommand.value?.metadata?.results
55+
return commandResults ? commandResults as SearchResult[] : []
3956
})
4057

4158
const total = computed(() => {
@@ -45,7 +62,7 @@ export function useSearch() {
4562
/**
4663
* Execute search with current parameters
4764
*/
48-
async function search(params?: Partial<SearchRequest>) {
65+
async function search(params?: Partial<SearchRequest>): Promise<SearchResponse> {
4966
if (!query.value.trim() && !params?.query) {
5067
return { success: false, error: new Error('搜索词不能为空') }
5168
}
@@ -65,7 +82,15 @@ export function useSearch() {
6582
useVectorSearch: useVectorSearch.value,
6683
}
6784

68-
return executeCommand(searchParams)
85+
const rawResult = await executeCommand(searchParams)
86+
const result = rawResult as unknown as Command
87+
88+
return {
89+
success: result.status === 'completed',
90+
error: result.error,
91+
total: result.metadata?.total,
92+
results: result.metadata?.results,
93+
}
6994
}
7095

7196
/**

apps/frontend/src/pages/search.vue

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,10 @@ function formatDate(date: string | Date): string {
202202
</div>
203203

204204
<!-- Pagination -->
205-
<div v-if="total > pageSize" class="mt-8 flex justify-center">
205+
<div v-if="total && Number(total) > pageSize" class="mt-8 flex justify-center">
206206
<nav class="flex items-center gap-2">
207207
<button
208-
v-for="page in Math.ceil(total / pageSize)"
208+
v-for="page in Math.ceil(Number(total) / pageSize)"
209209
:key="page"
210210
class="rounded-lg px-3 py-1"
211211
:class="{

apps/frontend/src/pages/settings.vue

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,14 @@ loadConfig()
397397
class="mt-1 block w-full border border-gray-300 rounded-md px-3 py-2 dark:border-gray-700 dark:bg-gray-800 dark:text-white"
398398
>
399399
</div>
400+
<div>
401+
<label class="block text-sm text-gray-700 font-medium dark:text-gray-300">{{ $t('pages.settings.dimensions') }}</label>
402+
<input
403+
v-model="config.api.embedding.dimensions"
404+
:disabled="!isEditing"
405+
class="mt-1 block w-full border border-gray-300 rounded-md px-3 py-2 dark:border-gray-700 dark:bg-gray-800 dark:text-white"
406+
>
407+
</div>
400408
<div>
401409
<label class="block text-sm text-gray-700 font-medium dark:text-gray-300">{{ $t('pages.settings.api_key') }}</label>
402410
<input

apps/frontend/src/types/search.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
export interface SearchResult {
2+
id: number
3+
chatId: number
4+
fromId?: number
5+
fromName?: string
6+
type: string
7+
content: string
8+
createdAt: string
9+
score: number
10+
}
11+
12+
export interface SearchResponse {
13+
success: boolean
14+
error?: Error
15+
total?: number
16+
results?: SearchResult[]
17+
}

apps/server/src/routes/config.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ const configSchema = z.object({
4242
embedding: z.object({
4343
provider: z.enum(['openai', 'ollama']),
4444
model: z.string(),
45+
dimensions: z.number(),
4546
apiKey: z.string().optional(),
4647
apiBase: z.string().optional(),
4748
}),

apps/server/src/routes/search.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ export function setupSearchRoutes(app: App) {
2323
const validatedBody = searchCommandSchema.parse(body)
2424

2525
logger.withFields(validatedBody).debug('Search request received')
26-
2726
return createSSEResponse(async (controller) => {
2827
await commandManager.executeCommand('search', null, validatedBody, controller)
2928
})

apps/server/src/services/commands/embed.ts

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import type { EmbedDetails } from '../../types/apis/embed'
44

55
import { useLogger } from '@tg-search/common'
66
import { EmbeddingService } from '@tg-search/core'
7-
import { findMessagesByChatId, updateMessageEmbeddings } from '@tg-search/db'
7+
import { findMessageMissingEmbed, updateMessageEmbeddings, useEmbeddingTable } from '@tg-search/db'
88
import { z } from 'zod'
99

1010
const logger = useLogger()
@@ -48,6 +48,7 @@ export class EmbedCommandHandler {
4848

4949
// Initialize embedding service
5050
const embedding = new EmbeddingService()
51+
await useEmbeddingTable(chatId, embedding.getEmbeddingConfig())
5152
const command: EmbedCommand = {
5253
id: crypto.randomUUID(),
5354
type: 'sync',
@@ -65,9 +66,8 @@ export class EmbedCommandHandler {
6566

6667
try {
6768
// Get all messages for the chat
68-
const messages = await findMessagesByChatId(chatId)
69-
const messagesToEmbed = messages.items.filter(m => !m.embedding && m.content)
70-
const totalMessages = messagesToEmbed.length
69+
const messages = await findMessageMissingEmbed(chatId, embedding.getEmbeddingConfig())
70+
const totalMessages = messages.length
7171
const totalBatches = Math.ceil(totalMessages / batchSize)
7272

7373
command.metadata.totalMessages = totalMessages
@@ -83,29 +83,33 @@ export class EmbedCommandHandler {
8383
let failedBatches = 0
8484

8585
// Split messages into batches
86-
for (let i = 0; i < messagesToEmbed.length; i += batchSize) {
86+
for (let i = 0; i < messages.length; i += batchSize) {
8787
const currentBatch = Math.floor(i / batchSize) + 1
8888
command.metadata.currentBatch = currentBatch
8989

90-
const batch = messagesToEmbed.slice(i, i + batchSize)
90+
const batch = messages.slice(i, i + batchSize)
9191
logger.debug(`Processing batch ${currentBatch}/${totalBatches} (${batch.length} messages)`)
9292

9393
try {
9494
// Generate embeddings in parallel
95-
const contents = batch.map(m => m.content!)
95+
const contents = batch
96+
.filter(m => m.content)
97+
.map(m => m.content as string)
9698
const embeddings = await embedding.generateEmbeddings(contents)
9799

98100
// Prepare updates
99-
const updates = batch.map((message, index) => ({
100-
id: message.id,
101-
embedding: embeddings[index],
102-
}))
101+
const updates = batch
102+
.filter(message => message.uuid)
103+
.map((message, index) => ({
104+
id: message.uuid!,
105+
embedding: embeddings[index],
106+
}))
103107

104108
// Update embeddings in batches with concurrency control
105109
for (let j = 0; j < updates.length; j += concurrency) {
106110
const concurrentBatch = updates.slice(j, j + concurrency)
107111
try {
108-
await updateMessageEmbeddings(chatId, concurrentBatch)
112+
await updateMessageEmbeddings(chatId, concurrentBatch, embedding.getEmbeddingConfig())
109113
totalProcessed += concurrentBatch.length
110114
command.metadata.processedMessages = totalProcessed
111115
}

apps/server/src/services/commands/search.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ export class SearchCommandHandler {
6969
// Vector search
7070
const embedding = new EmbeddingService()
7171
const queryEmbedding = await embedding.generateEmbedding(params.query)
72-
const results = await findSimilarMessages(queryEmbedding, {
72+
const results = await findSimilarMessages(targetChatId || 0, queryEmbedding, embedding.getEmbeddingConfig(), {
7373
chatId: targetChatId || 0,
7474
limit: params.limit * 2,
7575
offset: params.offset,

config/config.example.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,6 @@ api:
5858
model: text-embedding-3-small
5959
# API key for provider
6060
apiKey: your_openai_api_key
61+
dimensions: 1536
6162
# Optional, for custom API providers
6263
apiBase: 'https://api.openai.com/v1'

packages/common/src/composable/config.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ const DEFAULT_CONFIG = {
9090
apiKey: '',
9191
// Optional API base URL
9292
apiBase: '',
93+
// Embedding dimensions
94+
dimensions: 1536,
9395
},
9496
},
9597
} as const satisfies Config

packages/common/src/types/config.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ export interface ApiConfig {
4545
embedding: {
4646
provider: 'ollama' | 'openai'
4747
model: string
48+
dimensions: number
4849
apiKey?: string
4950
apiBase?: string
5051
}

packages/core/src/services/embedding.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { EmbeddingTableConfig } from '@tg-search/db'
12
import type { EmbeddingModelConfig } from '../types'
23

34
import { getConfig, useLogger } from '@tg-search/common'
@@ -7,14 +8,26 @@ import { EmbeddingAdapter } from '../adapter/embedding/adapter'
78
* Service for generating embeddings from text
89
*/
910
export class EmbeddingService {
10-
private embedding
11+
private embedding: EmbeddingAdapter
12+
private embedding_config: EmbeddingTableConfig
1113
private logger = useLogger()
1214
private config = getConfig()
1315

1416
constructor() {
1517
this.embedding = new EmbeddingAdapter({
1618
...this.config.api.embedding,
1719
} as EmbeddingModelConfig)
20+
this.embedding_config = {
21+
...this.config.api.embedding,
22+
dimensions: this.config.api.embedding.dimensions,
23+
} as EmbeddingTableConfig
24+
}
25+
26+
/**
27+
* 获取模型配置
28+
*/
29+
getEmbeddingConfig(): EmbeddingTableConfig {
30+
return this.embedding_config
1831
}
1932

2033
/**

packages/core/src/services/export.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ export class ExportService {
202202
let count = 0
203203
let failedCount = 0
204204
let messages: TelegramMessage[] = []
205-
205+
logger.debug(`history ${JSON.stringify(history)}`)
206206
const total = limit || history.count - 1 || 100
207207
function isSkipMedia(type: DatabaseMessageType) {
208208
return !messageTypes.includes(type)
@@ -229,7 +229,7 @@ export class ExportService {
229229
}
230230
try {
231231
// Try to export messages
232-
for await (const message of this.client.getMessages(chatId, undefined, {
232+
for await (const message of this.client.getMessages(chatId, total, {
233233
skipMedia: isSkipMedia('photo') || isSkipMedia('video') || isSkipMedia('document') || isSkipMedia('sticker'),
234234
startTime,
235235
endTime,

0 commit comments

Comments
 (0)