diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java index a812e5d91..b0feb6930 100644 --- a/src/it/java/io/weaviate/integration/SearchITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -32,6 +32,7 @@ import io.weaviate.client6.v1.api.collections.generate.GenerativeObject; import io.weaviate.client6.v1.api.collections.generate.TaskOutput; import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; +import io.weaviate.client6.v1.api.collections.query.Diversity; import io.weaviate.client6.v1.api.collections.query.FetchObjectById; import io.weaviate.client6.v1.api.collections.query.Filter; import io.weaviate.client6.v1.api.collections.query.GroupBy; @@ -884,4 +885,17 @@ public void testQueryProfile_groupBy() throws Exception { InstanceOfAssertFactories.map(String.class, String.class)) .isNotEmpty()); } + + @Test + public void testDiversity() throws Exception { + Version.V137.orSkip(); + + var things = client.collections.use(COLLECTION); + var resp = things.query.nearVector(searchVector, + opt -> opt.diversity(Diversity.mmr(div -> div.limit(3)))); + + Assertions.assertThat(resp) + .extracting(QueryResponse::objects, InstanceOfAssertFactories.LIST) + .hasSize(3); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseVectorSearchBuilder.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseVectorSearchBuilder.java index 31e4c4075..1d78b4698 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseVectorSearchBuilder.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseVectorSearchBuilder.java @@ -7,6 +7,7 @@ abstract class BaseVectorSearchBuilder { + public enum Kind { + MMR; + } + + /** Use {@link Mmr} diversity selection. */ + public static Mmr mmr(Function> fn) { + return Mmr.of(fn); + } + + /** Return this diversity as concrete {@link Mmr} type. */ + public default Mmr asMmr() { + return _as(Diversity.Kind.MMR); + } + + /** Maximal Marginal Relevance diversity selection. */ + public record Mmr(Integer limit, Float balance) implements Diversity { + + @Override + public Diversity.Kind _kind() { + return Diversity.Kind.MMR; + } + + @Override + public Object _self() { + return this; + } + + public static Mmr of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Mmr(Builder builder) { + this(builder.limit, builder.balance); + } + + public static class Builder implements ObjectBuilder { + private Integer limit; + private Float balance; + + public Builder limit(int limit) { + this.limit = limit; + return this; + } + + public Builder balance(float balance) { + this.balance = balance; + return this; + } + + @Override + public Mmr build() { + return new Mmr(this); + } + } + } + + public default WeaviateProtoBaseSearch.Selection.Builder toProto() { + var selection = WeaviateProtoBaseSearch.Selection.newBuilder(); + switch (_kind()) { + case MMR: + var mmrBuilder = WeaviateProtoBaseSearch.Selection.MMR.newBuilder(); + var mmr = asMmr(); + if (mmr.limit() != null) { + mmrBuilder.setLimit(mmr.limit()); + } + if (mmr.balance() != null) { + mmrBuilder.setBalance(mmr.balance()); + } + selection.setMmr(mmrBuilder); + break; + } + return selection; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearAudio.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearAudio.java index aa943849f..ec80c9e76 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearAudio.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearAudio.java @@ -10,7 +10,13 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; -public record NearAudio(Target searchTarget, Float distance, Float certainty, Rerank rerank, BaseQueryOptions common) +public record NearAudio( + Target searchTarget, + Float distance, + Float certainty, + Rerank rerank, + Diversity diversity, + BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { public static NearAudio of(String audio) { @@ -35,6 +41,7 @@ public NearAudio(Builder builder) { builder.distance, builder.certainty, builder.rerank, + builder.diversity, builder.baseOptions()); } @@ -81,7 +88,9 @@ private WeaviateProtoBaseSearch.NearAudioSearch.Builder protoBuilder() { } else if (distance != null) { nearAudio.setDistance(distance); } - + if (diversity != null) { + nearAudio.setSelection(diversity.toProto()); + } return nearAudio; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearDepth.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearDepth.java index 7405242dc..cd64b7dfd 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearDepth.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearDepth.java @@ -15,6 +15,7 @@ public record NearDepth( Float distance, Float certainty, Rerank rerank, + Diversity diversity, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { @@ -40,6 +41,7 @@ public NearDepth(Builder builder) { builder.distance, builder.certainty, builder.rerank, + builder.diversity, builder.baseOptions()); } @@ -86,6 +88,9 @@ private WeaviateProtoBaseSearch.NearDepthSearch.Builder protoBuilder() { } else if (distance != null) { nearDepth.setDistance(distance); } + if (diversity != null) { + nearDepth.setSelection(diversity.toProto()); + } return nearDepth; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearImage.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearImage.java index 960368807..70b53f56d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearImage.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearImage.java @@ -15,6 +15,7 @@ public record NearImage( Float distance, Float certainty, Rerank rerank, + Diversity diversity, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { @@ -40,6 +41,7 @@ public NearImage(Builder builder) { builder.distance, builder.certainty, builder.rerank, + builder.diversity, builder.baseOptions()); } @@ -86,6 +88,9 @@ private WeaviateProtoBaseSearch.NearImageSearch.Builder protoBuilder() { } else if (distance != null) { nearImage.setDistance(distance); } + if (diversity != null) { + nearImage.setSelection(diversity.toProto()); + } return nearImage; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearImu.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearImu.java index 420fcde6e..b38432516 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearImu.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearImu.java @@ -15,6 +15,7 @@ public record NearImu( Float distance, Float certainty, Rerank rerank, + Diversity diversity, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { @@ -40,6 +41,7 @@ public NearImu(Builder builder) { builder.distance, builder.certainty, builder.rerank, + builder.diversity, builder.baseOptions()); } @@ -86,6 +88,10 @@ private WeaviateProtoBaseSearch.NearIMUSearch.Builder protoBuilder() { } else if (distance != null) { nearImu.setDistance(distance); } + if (diversity != null) { + nearImu.setSelection(diversity.toProto()); + } + return nearImu; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearObject.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearObject.java index 9f08e5243..9a50c6130 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearObject.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearObject.java @@ -13,6 +13,7 @@ public record NearObject( Float distance, Float certainty, Rerank rerank, + Diversity diversity, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { @@ -30,6 +31,7 @@ public NearObject(Builder builder) { builder.distance, builder.certainty, builder.rerank, + builder.diversity, builder.baseOptions()); } @@ -74,6 +76,9 @@ private WeaviateProtoBaseSearch.NearObject.Builder protoBuilder() { } else if (distance != null) { nearObject.setDistance(distance); } + if (diversity != null) { + nearObject.setSelection(diversity.toProto()); + } return nearObject; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java index 85509acd4..817223ec1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java @@ -20,6 +20,7 @@ public record NearText( Rerank rerank, Move moveTo, Move moveAway, + Diversity diversity, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { public static NearText of(String... concepts) { @@ -46,6 +47,7 @@ public NearText(Builder builder) { builder.rerank, builder.moveTo, builder.moveAway, + builder.diversity, builder.baseOptions()); } @@ -166,7 +168,9 @@ WeaviateProtoBaseSearch.NearTextSearch.Builder protoBuilder(boolean withTargets) moveAway.appendTo(away); nearText.setMoveAway(away); } - + if (diversity != null) { + nearText.setSelection(diversity.toProto()); + } return nearText; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearThermal.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearThermal.java index dc91ed03e..bfca43c94 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearThermal.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearThermal.java @@ -14,6 +14,7 @@ public record NearThermal(Target searchTarget, Float distance, Float certainty, Rerank rerank, + Diversity diversity, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { @@ -39,6 +40,7 @@ public NearThermal(Builder builder) { builder.distance, builder.certainty, builder.rerank, + builder.diversity, builder.baseOptions()); } @@ -85,6 +87,9 @@ private WeaviateProtoBaseSearch.NearThermalSearch.Builder protoBuilder() { } else if (distance != null) { nearThermal.setDistance(distance); } + if (diversity != null) { + nearThermal.setSelection(diversity.toProto()); + } return nearThermal; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java index 134c019c9..842abcee2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java @@ -12,6 +12,7 @@ public record NearVector(NearVectorTarget searchTarget, Float distance, Float certainty, Rerank rerank, + Diversity diversity, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { @@ -44,6 +45,7 @@ public NearVector(Builder builder) { builder.distance, builder.certainty, builder.rerank, + builder.diversity, builder.baseOptions()); } @@ -90,6 +92,10 @@ WeaviateProtoBaseSearch.NearVector.Builder protoBuilder(boolean withTargets) { } else if (distance != null) { nearVector.setDistance(distance); } + + if (diversity != null) { + nearVector.setSelection(diversity.toProto()); + } return nearVector; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVideo.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVideo.java index fb8974216..90e5d6ca0 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVideo.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVideo.java @@ -15,6 +15,7 @@ public record NearVideo( Float distance, Float certainty, Rerank rerank, + Diversity diversity, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { @@ -40,6 +41,7 @@ public NearVideo(Builder builder) { builder.distance, builder.certainty, builder.rerank, + builder.diversity, builder.baseOptions()); } @@ -86,6 +88,9 @@ private WeaviateProtoBaseSearch.NearVideoSearch.Builder protoBuilder() { } else if (distance != null) { nearVideo.setDistance(distance); } + if (diversity != null) { + nearVideo.setSelection(diversity.toProto()); + } return nearVideo; } }