Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/it/java/io/weaviate/integration/SearchITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.weaviate.client6.v1.api.collections.generate.GenerativeObject;
import io.weaviate.client6.v1.api.collections.generate.TaskOutput;
import io.weaviate.client6.v1.api.collections.generative.DummyGenerative;
import io.weaviate.client6.v1.api.collections.query.Diversity;
import io.weaviate.client6.v1.api.collections.query.FetchObjectById;
import io.weaviate.client6.v1.api.collections.query.Filter;
import io.weaviate.client6.v1.api.collections.query.GroupBy;
Expand Down Expand Up @@ -884,4 +885,17 @@ public void testQueryProfile_groupBy() throws Exception {
InstanceOfAssertFactories.map(String.class, String.class))
.isNotEmpty());
}

@Test
public void testDiversity() throws Exception {
Version.V137.orSkip();

var things = client.collections.use(COLLECTION);
var resp = things.query.nearVector(searchVector,
opt -> opt.diversity(Diversity.mmr(div -> div.limit(3))));

Assertions.assertThat(resp)
.extracting(QueryResponse::objects, InstanceOfAssertFactories.LIST)
.hasSize(3);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ abstract class BaseVectorSearchBuilder<SelfT extends BaseVectorSearchBuilder<Sel
Float distance;
Float certainty;
Rerank rerank;
Diversity diversity;

/**
* Discard objects whose vectors are further away
Expand Down Expand Up @@ -45,4 +46,13 @@ public SelfT rerank(Rerank rerank) {
this.rerank = rerank;
return (SelfT) this;
}

/**
* Apply diversity selection to the query results.
*/
@SuppressWarnings("unchecked")
public SelfT diversity(Diversity diversity) {
this.diversity = diversity;
return (SelfT) this;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package io.weaviate.client6.v1.api.collections.query;

import java.util.function.Function;

import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.TaggedUnion;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;

public interface Diversity extends TaggedUnion<Diversity.Kind, Object> {
public enum Kind {
MMR;
}

/** Use {@link Mmr} diversity selection. */
public static Mmr mmr(Function<Mmr.Builder, ObjectBuilder<Mmr>> fn) {
return Mmr.of(fn);
}

/** Return this diversity as concrete {@link Mmr} type. */
public default Mmr asMmr() {
return _as(Diversity.Kind.MMR);
}

/** Maximal Marginal Relevance diversity selection. */
public record Mmr(Integer limit, Float balance) implements Diversity {

@Override
public Diversity.Kind _kind() {
return Diversity.Kind.MMR;
}

@Override
public Object _self() {
return this;
}

public static Mmr of(Function<Builder, ObjectBuilder<Mmr>> fn) {
return fn.apply(new Builder()).build();
}

public Mmr(Builder builder) {
this(builder.limit, builder.balance);
}

public static class Builder implements ObjectBuilder<Mmr> {
private Integer limit;
private Float balance;

public Builder limit(int limit) {
this.limit = limit;
return this;
}

public Builder balance(float balance) {
this.balance = balance;
return this;
}

@Override
public Mmr build() {
return new Mmr(this);
}
}
}

public default WeaviateProtoBaseSearch.Selection.Builder toProto() {
var selection = WeaviateProtoBaseSearch.Selection.newBuilder();
switch (_kind()) {
case MMR:
var mmrBuilder = WeaviateProtoBaseSearch.Selection.MMR.newBuilder();
var mmr = asMmr();
if (mmr.limit() != null) {
mmrBuilder.setLimit(mmr.limit());
}
if (mmr.balance() != null) {
mmrBuilder.setBalance(mmr.balance());
}
selection.setMmr(mmrBuilder);
break;
}
return selection;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;

public record NearAudio(Target searchTarget, Float distance, Float certainty, Rerank rerank, BaseQueryOptions common)
public record NearAudio(
Target searchTarget,
Float distance,
Float certainty,
Rerank rerank,
Diversity diversity,
BaseQueryOptions common)
implements QueryOperator, AggregateObjectFilter {

public static NearAudio of(String audio) {
Expand All @@ -35,6 +41,7 @@ public NearAudio(Builder builder) {
builder.distance,
builder.certainty,
builder.rerank,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -81,7 +88,9 @@ private WeaviateProtoBaseSearch.NearAudioSearch.Builder protoBuilder() {
} else if (distance != null) {
nearAudio.setDistance(distance);
}

if (diversity != null) {
nearAudio.setSelection(diversity.toProto());
}
return nearAudio;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public record NearDepth(
Float distance,
Float certainty,
Rerank rerank,
Diversity diversity,
BaseQueryOptions common)
implements QueryOperator, AggregateObjectFilter {

Expand All @@ -40,6 +41,7 @@ public NearDepth(Builder builder) {
builder.distance,
builder.certainty,
builder.rerank,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -86,6 +88,9 @@ private WeaviateProtoBaseSearch.NearDepthSearch.Builder protoBuilder() {
} else if (distance != null) {
nearDepth.setDistance(distance);
}
if (diversity != null) {
nearDepth.setSelection(diversity.toProto());
}
return nearDepth;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public record NearImage(
Float distance,
Float certainty,
Rerank rerank,
Diversity diversity,
BaseQueryOptions common)
implements QueryOperator, AggregateObjectFilter {

Expand All @@ -40,6 +41,7 @@ public NearImage(Builder builder) {
builder.distance,
builder.certainty,
builder.rerank,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -86,6 +88,9 @@ private WeaviateProtoBaseSearch.NearImageSearch.Builder protoBuilder() {
} else if (distance != null) {
nearImage.setDistance(distance);
}
if (diversity != null) {
nearImage.setSelection(diversity.toProto());
}
return nearImage;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public record NearImu(
Float distance,
Float certainty,
Rerank rerank,
Diversity diversity,
BaseQueryOptions common)
implements QueryOperator, AggregateObjectFilter {

Expand All @@ -40,6 +41,7 @@ public NearImu(Builder builder) {
builder.distance,
builder.certainty,
builder.rerank,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -86,6 +88,10 @@ private WeaviateProtoBaseSearch.NearIMUSearch.Builder protoBuilder() {
} else if (distance != null) {
nearImu.setDistance(distance);
}
if (diversity != null) {
nearImu.setSelection(diversity.toProto());
}

return nearImu;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public record NearObject(
Float distance,
Float certainty,
Rerank rerank,
Diversity diversity,
BaseQueryOptions common)
implements QueryOperator, AggregateObjectFilter {

Expand All @@ -30,6 +31,7 @@ public NearObject(Builder builder) {
builder.distance,
builder.certainty,
builder.rerank,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -74,6 +76,9 @@ private WeaviateProtoBaseSearch.NearObject.Builder protoBuilder() {
} else if (distance != null) {
nearObject.setDistance(distance);
}
if (diversity != null) {
nearObject.setSelection(diversity.toProto());
}
return nearObject;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public record NearText(
Rerank rerank,
Move moveTo,
Move moveAway,
Diversity diversity,
BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter {

public static NearText of(String... concepts) {
Expand All @@ -46,6 +47,7 @@ public NearText(Builder builder) {
builder.rerank,
builder.moveTo,
builder.moveAway,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -166,7 +168,9 @@ WeaviateProtoBaseSearch.NearTextSearch.Builder protoBuilder(boolean withTargets)
moveAway.appendTo(away);
nearText.setMoveAway(away);
}

if (diversity != null) {
nearText.setSelection(diversity.toProto());
}
return nearText;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public record NearThermal(Target searchTarget,
Float distance,
Float certainty,
Rerank rerank,
Diversity diversity,
BaseQueryOptions common)
implements QueryOperator, AggregateObjectFilter {

Expand All @@ -39,6 +40,7 @@ public NearThermal(Builder builder) {
builder.distance,
builder.certainty,
builder.rerank,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -85,6 +87,9 @@ private WeaviateProtoBaseSearch.NearThermalSearch.Builder protoBuilder() {
} else if (distance != null) {
nearThermal.setDistance(distance);
}
if (diversity != null) {
nearThermal.setSelection(diversity.toProto());
}
return nearThermal;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public record NearVector(NearVectorTarget searchTarget,
Float distance,
Float certainty,
Rerank rerank,
Diversity diversity,
BaseQueryOptions common)
implements QueryOperator, AggregateObjectFilter {

Expand Down Expand Up @@ -44,6 +45,7 @@ public NearVector(Builder builder) {
builder.distance,
builder.certainty,
builder.rerank,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -90,6 +92,10 @@ WeaviateProtoBaseSearch.NearVector.Builder protoBuilder(boolean withTargets) {
} else if (distance != null) {
nearVector.setDistance(distance);
}

if (diversity != null) {
nearVector.setSelection(diversity.toProto());
}
return nearVector;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public record NearVideo(
Float distance,
Float certainty,
Rerank rerank,
Diversity diversity,
BaseQueryOptions common)
implements QueryOperator, AggregateObjectFilter {

Expand All @@ -40,6 +41,7 @@ public NearVideo(Builder builder) {
builder.distance,
builder.certainty,
builder.rerank,
builder.diversity,
builder.baseOptions());
}

Expand Down Expand Up @@ -86,6 +88,9 @@ private WeaviateProtoBaseSearch.NearVideoSearch.Builder protoBuilder() {
} else if (distance != null) {
nearVideo.setDistance(distance);
}
if (diversity != null) {
nearVideo.setSelection(diversity.toProto());
}
return nearVideo;
}
}
Loading