Skip to content

Commit 4635ee6

Browse files
committed
tracking internal data in the train bucket
1 parent 518d4fd commit 4635ee6

2 files changed

Lines changed: 61 additions & 48 deletions

File tree

index/scorch/merge.go

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package scorch
1717
import (
1818
"context"
1919
"fmt"
20-
"math"
2120
"os"
2221
"strings"
2322
"sync"
@@ -484,7 +483,6 @@ type mergedSegmentHistory struct {
484483
type segmentMerge struct {
485484
id []uint64
486485
new []segment.Segment
487-
trainData [][]float32
488486
mergedSegHistory map[uint64]*mergedSegmentHistory
489487
notifyCh chan *mergeTaskIntroStatus
490488
mmaped uint32
@@ -531,27 +529,6 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
531529
var em sync.Mutex
532530
var errs []error
533531

534-
var trainingSample [][]float32
535-
collectTrainData := func(segTrainData [][]float32) {
536-
trainingSample = append(trainingSample, segTrainData...)
537-
}
538-
539-
// numDocs, err := snapshot.DocCount()
540-
// if err != nil {
541-
// return nil, nil, err
542-
// }
543-
544-
// hardcoding the total docs for now, need to get it from CB level
545-
numDocs := 1000000
546-
trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(numDocs)) * 50)
547-
548-
// collect train data only if needed
549-
if len(snapshot.trainData)/768 < int(trainingSampleSize) {
550-
s.segmentConfig["collectTrainDataCallback"] = collectTrainData
551-
} else {
552-
s.segmentConfig["trainData"] = snapshot.trainData
553-
}
554-
555532
// deploy the workers to merge and flush the batches of segments concurrently
556533
// and create a new file segment
557534
for i := 0; i < numFlushes; i++ {

index/scorch/train_vector.go

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"fmt"
2424
"os"
2525
"path/filepath"
26+
"strconv"
2627
"strings"
2728
"sync"
2829

@@ -36,10 +37,13 @@ import (
3637
)
3738

3839
type trainRequest struct {
39-
vecCount int
40-
ackCh chan error
41-
trainComplete []byte
42-
sample segment.Segment
40+
ackCh chan error
41+
sample segment.Segment
42+
43+
// metadata to write out to bolt
44+
trainComplete bool
45+
vecCount uint64
46+
internalData []byte
4347
}
4448

4549
func initTrainer(s *Scorch, config map[string]interface{}) *vectorTrainer {
@@ -89,6 +93,8 @@ func (t *vectorTrainer) trainLoop() {
8993
t.parent.asyncTasks.Done()
9094
}()
9195
// initialize stuff
96+
totalSamplesProcessed := t.centroidIndex.cachedMeta.fetchMeta("trainedSamples").(uint64)
97+
buf := make([]byte, binary.MaxVarintLen64)
9298
t.parent.segmentConfig[index.CentroidIndexCallback] = t.getCentroidIndex
9399
path := filepath.Join(t.parent.path, index.CentroidIndexFileName)
94100
for {
@@ -108,7 +114,7 @@ func (t *vectorTrainer) trainLoop() {
108114
}
109115
default:
110116
}
111-
} else {
117+
} else if sampleSeg != nil {
112118
// merge the new segment with the existing one, no need to persist?
113119
// persist in a tmp file and then rename - is that a fair strategy?
114120
t.parent.segmentConfig[index.TrainingKey] = true
@@ -163,29 +169,44 @@ func (t *vectorTrainer) trainLoop() {
163169
return
164170
}
165171

172+
// update the path of the centroid index segmentSnapshot in bolt
166173
err = trainerBucket.Put(util.BoltPathKey, []byte(index.CentroidIndexFileName))
167174
if err != nil {
168175
trainReq.ackCh <- fmt.Errorf("error updating centroid bucket: %v", err)
169176
close(trainReq.ackCh)
170177
return
171178
}
172179

173-
err = trainerBucket.Put(util.BoltTrainCompleteKey, trainReq.trainComplete)
180+
// train status - value is controlled by the application layer
181+
var comp byte
182+
if trainReq.trainComplete {
183+
comp = 1
184+
}
185+
err = trainerBucket.Put(util.BoltTrainCompleteKey, []byte{comp})
174186
if err != nil {
175187
trainReq.ackCh <- fmt.Errorf("error updating train complete bucket: %v", err)
176188
close(trainReq.ackCh)
177189
return
178190
}
179191

180-
// track progress of training in temrs of samples processed
181-
binary.LittleEndian.PutUint64(buf, uint64(totalSamplesProcessed))
192+
totalSamplesProcessed += trainReq.vecCount
193+
// track progress of training in terms of samples processed
194+
binary.LittleEndian.PutUint64(buf, totalSamplesProcessed)
182195
err = trainerBucket.Put(util.BoltTrainedSamplesKey, buf)
183196
if err != nil {
184197
trainReq.ackCh <- fmt.Errorf("error updating trained samples bucket: %v", err)
185198
close(trainReq.ackCh)
186199
return
187200
}
188201

202+
// training related internal data that needs to be stored as per
203+
// application layer via the SetInternal and GetInternal APIs
204+
err = trainerBucket.Put(util.BoltInternalKey, trainReq.internalData)
205+
if err != nil {
206+
trainReq.ackCh <- fmt.Errorf("error updating train internal bucket: %v", err)
207+
close(trainReq.ackCh)
208+
return
209+
}
189210
err = tx.Commit()
190211
if err != nil {
191212
trainReq.ackCh <- fmt.Errorf("error committing bolt transaction: %v", err)
@@ -226,6 +247,15 @@ func (t *vectorTrainer) loadTrainedData(bucket *bolt.Bucket) error {
226247
if err != nil {
227248
return err
228249
}
250+
251+
internalData := bucket.Get(util.BoltInternalKey)
252+
trainedSamples := bucket.Get(util.BoltTrainedSamplesKey)
253+
trainComplete := bucket.Get(util.BoltTrainCompleteKey)
254+
255+
segmentSnapshot.cachedMeta.updateMeta("internalData", internalData)
256+
segmentSnapshot.cachedMeta.updateMeta("trainComplete", trainComplete)
257+
segmentSnapshot.cachedMeta.updateMeta("trainedSamples", binary.LittleEndian.Uint64(trainedSamples))
258+
229259
t.m.Lock()
230260
defer t.m.Unlock()
231261
t.centroidIndex = segmentSnapshot
@@ -250,38 +280,37 @@ func (t *vectorTrainer) train(batch *index.Batch) error {
250280
}
251281
}
252282

253-
// just builds a new vector index out of the train data provided
254-
// this is not necessarily the final train data since this is submitted
255-
// as a request to the trainer component to be merged. once the training
256-
// is complete, the template will be used for other operations down the line
257-
// like merge and search.
258-
//
259-
// note: this might index text data too, how to handle this? s.segmentConfig?
260-
// todo: updates/deletes -> data drift detection
261-
seg, _, err := t.parent.segPlugin.NewUsing(trainData, t.parent.segmentConfig)
283+
var seg segment.Segment
284+
var fin bool
285+
var err error
286+
trainComplete := batch.InternalOps[string(util.BoltTrainCompleteKey)]
287+
if trainComplete == nil {
288+
trainComplete = []byte("false")
289+
}
290+
fin, err = strconv.ParseBool(string(trainComplete))
262291
if err != nil {
263292
return fmt.Errorf("error parsing train complete: %v", err)
264293
}
265294

266295
if !fin {
267296
// just builds a new vector index out of the train data provided
268-
// it'll be an IVF index so the centroids are computed at this stage and
269-
// this template will be used in the indexing down the line to index
270-
// the data vectors. s.segmentConfig will mark this as a training phase
271-
// and zap will handle it accordingly.
297+
// this is not necessarily the final train data since this is submitted
298+
// as a request to the trainer component to be merged. once the training
299+
// is complete, the template will be used for other operations down the line
300+
// like merge and search.
272301
//
273302
// note: this might index text data too, how to handle this? s.segmentConfig?
274303
// todo: updates/deletes -> data drift detection
275-
seg, _, err = t.parent.segPlugin.NewEx(trainData, t.parent.segmentConfig)
304+
seg, _, err = t.parent.segPlugin.NewUsing(trainData, t.parent.segmentConfig)
276305
if err != nil {
277306
return err
278307
}
279308
}
280309

281310
trainReq := &trainRequest{
311+
vecCount: uint64(len(trainData)),
282312
sample: seg,
283-
vecCount: len(trainData), // todo: multivector support
284-
trainComplete: trainComplete,
313+
trainComplete: fin,
285314
ackCh: make(chan error),
286315
}
287316

@@ -301,11 +330,18 @@ func (t *vectorTrainer) getInternal(key []byte) ([]byte, error) {
301330
switch string(key) {
302331
case string(util.BoltTrainCompleteKey):
303332
trainComplete := t.centroidIndex.cachedMeta.fetchMeta("trainComplete").([]byte)
304-
return trainComplete, nil
333+
if trainComplete[0] == 1 {
334+
return []byte("true"), nil
335+
}
336+
return []byte("false"), nil
305337
case string(util.BoltTrainedSamplesKey):
306338
// keep rv in a human readable format
307339
trainedSamples := t.centroidIndex.cachedMeta.fetchMeta("trainedSamples").(uint64)
308340
return []byte(fmt.Sprintf("%d", trainedSamples)), nil
341+
// keep the default fetch to be from the internal data
342+
default:
343+
internalData := t.centroidIndex.cachedMeta.fetchMeta("internalData").([]byte)
344+
return internalData, nil
309345
}
310346
}
311347
return nil, nil

0 commit comments

Comments
 (0)