Rerank重排序与辅助工具
前七篇讲完了 Chat 子系统(六篇)和 Embedding 子系统(一篇),覆盖了 infra-ai 模块 9 个包中的 6 个:config(配置)、enums(枚举)、model(路由核心)、http(HTTP 基础设施)、chat(对话)、embedding(向量化)。
这一篇讲剩下的 3 个包:rerank(重排序子系统)、token(Token 估算)、util(工具类),把整个 infra-ai 模块收尾。
Rerank 是三种能力中的最后一个——和 Chat、Embedding 遵循同样的三层接口设计,复用同一套路由和熔断机制。但 Rerank 有两个独特之处:一是百炼的 Rerank 实现里有一套去重 + 回填的防御性逻辑,处理混合检索场景的边缘情况;二是 NoopRerankClient 用空对象模式实现了优雅降级——没有 Rerank 模型时系统不报错,而是回退到简单截断。
Token 估算和响应清洗是两个小工具,代码量不大,但在 RAG 流程中经常用到。一并在本篇讲完。
Rerank 重排序子系统
1. Rerank 在 RAG 中的角色
回顾一下 RAG 的检索阶段。用户提问“AirPods Pro 2 的保修期是多久?”,系统先通过向量检索(Bi-Encoder) 从 100 万个 Chunk 中快速召回 Top-20 候选。这 20 个候选覆盖面广(保证了召回率),但排序不一定准——向量相似度是粗略的语义匹配,和 query 的真正相关度之间有差距。
Rerank 的作用是对这 20 个候选做精排。它用 Cross-Encoder 模型逐对评估 query 和每个候选的相关度,给出精确的分数,返回最相关的 Top-5 喂给大模型生成回答。
两阶段策略的分工:粗排(向量检索)保证覆盖率——宁可多召回一些不太相关的,也不要漏掉真正相关的;精排(Rerank)保证准确率——从粗排结果中挑出真正最相关的。这个数据漏斗在第一篇的架构总览中提过:100 万文档 → 向量检索召回 20 → Rerank 精排 5 → 大模型生成 1 个答案。
2. 接口设计
2.1 RetrievedChunk 数据结构
Rerank 操作的数据单元是 RetrievedChunk——定义在 framework 模块,是跨层共享的数据结构:
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class RetrievedChunk {
/**
* 命中记录的唯一标识(向量库主键)
*/
private String id;
/**
* 命中的文本内容(Chunk 正文)
*/
private String text;
/**
* 命中得分(数值越大相关性越高)
*/
private Float score;
}
三个字段:id 是向量数据库中的主键,text 是 Chunk 的文本内容,score 是相关度分数。向量检索阶段,score 是向量相似度分数;Rerank 之后,score 会被更新为 Rerank 模型给出的 relevance_score。
RetrievedChunk 放在 framework 模块而不是 infra-ai 模块——因为检索层、Rerank 层、生成层都用到它,是跨模块共享的约定。
2.2 RerankService 业务层接口
public interface RerankService {
/**
* 对候选文档进行精排,按相关度重新排序,返回前 topN 条
*
* @param query 用户问题
* @param candidates 向量检索召回的候选文档
* @param topN 最终保留的条数
* @return 精排后的前 topN 条文档
*/
List<RetrievedChunk> rerank(String query, List<RetrievedChunk> candidates, int topN);
}
只有一个方法,语义很清晰:传入 query 和候选列表,返回精排后的 Top-N。和 LLMService、EmbeddingService 的设计理念一致——业务层只看到这个接口,不感知供应商、路由、熔断等 infra 细节。
2.3 RerankClient 供应商接口
public interface RerankClient {
String provider();
List<RetrievedChunk> rerank(String query, List<RetrievedChunk> candidates,
int topN, ModelTarget target);
}
多了一个 ModelTarget 参数——和 ChatClient、EmbeddingClient 同样的设计。provider() 返回供应商标识,路由服务通过它查找对应的客户端实例。
3. BaiLianRerankClient 完整代码
@Service
@Slf4j
@RequiredArgsConstructor
public class BaiLianRerankClient implements RerankClient {
private final OkHttpClient httpClient;
@Override
public String provider() {
return ModelProvider.BAI_LIAN.getId();
}
@Override
public List<RetrievedChunk> rerank(String query, List<RetrievedChunk> candidates,
int topN, ModelTarget target) {
if (candidates == null || candidates.isEmpty()) {
return List.of();
}
List<RetrievedChunk> dedup = new ArrayList<>(candidates.size());
Set<String> seen = new HashSet<>();
for (RetrievedChunk rc : candidates) {
if (seen.add(rc.getId())) {
dedup.add(rc);
}
}
if (topN <= 0 || dedup.size() <= topN) {
return dedup;
}
return doRerank(query, dedup, topN, target);
}
private List<RetrievedChunk> doRerank(String query, List<RetrievedChunk> candidates,
int topN, ModelTarget target) {
AIModelProperties.ProviderConfig provider =
HttpResponseHelper.requireProvider(target, provider());
if (candidates == null || candidates.isEmpty() || topN <= 0) {
return List.of();
}
JsonObject reqBody = new JsonObject();
reqBody.addProperty("model", HttpResponseHelper.requireModel(target, provider()));
JsonObject input = new JsonObject();
input.addProperty("query", query);
JsonArray documentsArray = new JsonArray();
for (RetrievedChunk each : candidates) {
documentsArray.add(each.getText() == null ? "" : each.getText());
}
input.add("documents", documentsArray);
JsonObject parameters = new JsonObject();
parameters.addProperty("top_n", topN);
parameters.addProperty("return_documents", true);
reqBody.add("input", input);
reqBody.add("parameters", parameters);
Request request = new Request.Builder()
.url(ModelUrlResolver.resolveUrl(
provider, target.candidate(), ModelCapability.RERANK))
.post(RequestBody.create(reqBody.toString(), HttpMediaTypes.JSON))
.addHeader("Authorization", "Bearer " + provider.getApiKey())
.build();
JsonObject respJson;
try (Response response = httpClient.newCall(request).execute()) {
if (!response.isSuccessful()) {
String body = HttpResponseHelper.readBody(response.body());
log.warn("{} rerank 请求失败: status={}, body={}",
provider(), response.code(), body);
throw new ModelClientException(
provider() + " rerank 请求失败: HTTP " + response.code(),
ModelClientErrorType.fromHttpStatus(response.code()),
response.code()
);
}
respJson = HttpResponseHelper.parseJson(response.body(), provider());
} catch (IOException e) {
throw new ModelClientException(
provider() + " rerank 请求失败: " + e.getMessage(),
ModelClientErrorType.NETWORK_ERROR, null, e);
}
JsonObject output = requireOutput(respJson);
JsonArray results = output.getAsJsonArray("results");
if (CollUtil.isEmpty(results)) {
throw new ModelClientException(
provider() + " rerank results 为空",
ModelClientErrorType.INVALID_RESPONSE, null);
}
List<RetrievedChunk> reranked = new ArrayList<>();
Set<String> addedIds = new HashSet<>();
for (JsonElement elem : results) {
if (!elem.isJsonObject()) {
continue;
}
JsonObject item = elem.getAsJsonObject();
if (!item.has("index")) {
continue;
}
int idx = item.get("index").getAsInt();
if (idx < 0 || idx >= candidates.size()) {
continue;
}
RetrievedChunk src = candidates.get(idx);
Float score = null;
if (item.has("relevance_score") && !item.get("relevance_score").isJsonNull()) {
score = item.get("relevance_score").getAsFloat();
}
RetrievedChunk hit = score != null
? new RetrievedChunk(src.getId(), src.getText(), score) : src;
reranked.add(hit);
addedIds.add(src.getId());
if (reranked.size() >= topN) {
break;
}
}
if (reranked.size() < topN) {
for (RetrievedChunk c : candidates) {
if (addedIds.add(c.getId())) {
reranked.add(c);
}
if (reranked.size() >= topN) {
break;
}
}
}
return reranked;
}
private JsonObject requireOutput(JsonObject respJson) {
if (respJson == null || !respJson.has("output")) {
throw new ModelClientException(
provider() + " rerank 响应缺少 output",
ModelClientErrorType.INVALID_RESPONSE, null);
}
JsonObject output = respJson.getAsJsonObject("output");
if (output == null || !output.has("results")) {
throw new ModelClientException(
provider() + " rerank 响应缺少 results",
ModelClientErrorType.INVALID_RESPONSE, null);
}
return output;
}
}
代码不短,但逻辑分层清晰:rerank 方法做前置处理(去重 + 短路),doRerank 做实际的 HTTP 调用和响应解析。逐段拆解。
4. 去重:混合检索的重复问题
rerank 方法入口处有一段去重逻辑:
List<RetrievedChunk> dedup = new ArrayList<>(candidates.size());
Set<String> seen = new HashSet<>();
for (RetrievedChunk rc : candidates) {
if (seen.add(rc.getId())) {
dedup.add(rc);
}
}
为什么需要去重?
Ragent 的检索阶段采用混合检索——向量检索和 BM25 关键词检索并行执行,结果合并后做 RRF 排名融合。同一个 Chunk 可能同时被向量检索和 BM25 命中。合并后的候选列表中,这个 Chunk 就出现了两次。
当然,这里其实是为了兜底,从 Ragent AI 角度出发,会在搜索后的后置处理器里优先去重。
如果不去重就传给 Rerank API,documents 数组里有两条完全一样的文本,白白浪费 Token(Rerank 模型按 Token 计费),而且返回的 results 中可能出现两条指向同一个 Chunk 的结果,占掉 topN 中的两个名额。
去重按 id 做——HashSet<String> seen 记录已出现的 id,seen.add(rc.getId()) 返回 true 表示第一次出现,加入 dedup;返回 false 表示已存在,跳过。保留第一个出现的,后续重复的丢弃。
为什么按 id 去重而不是按 text?
不同文档中可能碰巧有相同的文本片段(比如通用的免责声明),它们是不同的 Chunk,有不同的 id,不应该被去重。按 id 去重是精确去重——只过滤真正的重复条目(同一个 Chunk 被多条检索通路命中),不误杀内容相同但来源不同的 Chunk。
5. 短路优化
if (topN <= 0 || dedup.size() <= topN) {
return dedup;
}
去重后如果候选数不超过 topN,没必要调 Rerank API——本来就不需要裁剪,全部返回就行。这种情况在候选较少或 topN 设得较大时会触发,省去一次 HTTP 调用的开销。