diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 2d138ce2ce..087d126b63 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -3431,6 +3431,16 @@ func (node *Proxy) handleIfSearchByPK(ctx context.Context, request *milvuspb.Sea return nil // Not search by PK, do nothing } + // Check for duplicate IDs (fail fast before query) + inputIDsCount := typeutil.GetSizeOfIDs(ids) + checker, err := typeutil.NewIDsChecker(ids) + if err != nil { + return err + } + if checker.Size() != inputIDsCount { + return merr.WrapErrParameterInvalidMsg("duplicate IDs found in search request") + } + // Get collection schema for validation and plan building collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, request.GetDbName(), request.GetCollectionName(), 0) @@ -3438,11 +3448,19 @@ func (node *Proxy) handleIfSearchByPK(ctx context.Context, request *milvuspb.Sea return err } - // Validate that anns_field is provided + // Get anns_field from search params, or infer from schema if only one vector field exists annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, request.SearchParams) if err != nil || annsFieldName == "" { - return merr.WrapErrParameterInvalid("valid anns_field in search_params", "missing", - "anns_field is required for search by IDs") + vecFields := typeutil.GetVectorFieldSchemas(collectionInfo.schema.CollectionSchema) + if len(vecFields) == 0 { + return merr.WrapErrParameterInvalid("valid anns_field in search_params", "missing", + "no vector field found in schema") + } + if enableMultipleVectorFields && len(vecFields) > 1 { + return merr.WrapErrParameterInvalid("valid anns_field in search_params", "missing", + "multiple vector fields exist, please specify anns_field in search_params") + } + annsFieldName = vecFields[0].Name } annField := typeutil.GetFieldByName(collectionInfo.schema.CollectionSchema, annsFieldName) @@ -3527,7 +3545,6 @@ func (node *Proxy) handleIfSearchByPK(ctx context.Context, request *milvuspb.Sea } // Check if the returned pk count matches the input IDs count - inputIDsCount := typeutil.GetSizeOfIDs(ids) returnedPKCount := typeutil.GetPKSize(pkFieldData) if returnedPKCount != inputIDsCount { // Find which IDs are missing