@@ -23,6 +23,7 @@ import (
2323 "fmt"
2424 "os"
2525 "path/filepath"
26+ "strconv"
2627 "strings"
2728 "sync"
2829
@@ -36,10 +37,13 @@ import (
3637)
3738
3839type trainRequest struct {
39- vecCount int
40- ackCh chan error
41- trainComplete []byte
42- sample segment.Segment
40+ ackCh chan error // trainLoop reports failures here, then closes the channel
41+ sample segment.Segment // segment built from the train batch; left nil when the batch marks training complete
42+
43+ // metadata that trainLoop writes to the trainer bucket in bolt
44+ trainComplete bool // stored as a single byte: 1 when true, 0 otherwise
45+ vecCount uint64 // len of the batch's train data; added to the running trained-samples total
46+ internalData []byte // stored under util.BoltInternalKey
4347}
4448
4549func initTrainer (s * Scorch , config map [string ]interface {}) * vectorTrainer {
@@ -89,6 +93,8 @@ func (t *vectorTrainer) trainLoop() {
8993 t .parent .asyncTasks .Done ()
9094 }()
9195 // initialize stuff
96+ totalSamplesProcessed := t .centroidIndex .cachedMeta .fetchMeta ("trainedSamples" ).(uint64 )
97+ buf := make ([]byte , binary .MaxVarintLen64 )
9298 t .parent .segmentConfig [index .CentroidIndexCallback ] = t .getCentroidIndex
9399 path := filepath .Join (t .parent .path , index .CentroidIndexFileName )
94100 for {
@@ -108,7 +114,7 @@ func (t *vectorTrainer) trainLoop() {
108114 }
109115 default :
110116 }
111- } else {
117+ } else if sampleSeg != nil {
112118 // merge the new segment with the existing one, no need to persist?
113119 // persist in a tmp file and then rename - is that a fair strategy?
114120 t .parent .segmentConfig [index .TrainingKey ] = true
@@ -163,29 +169,44 @@ func (t *vectorTrainer) trainLoop() {
163169 return
164170 }
165171
172+ // update the path of the centroid index segmentSnapshot in bolt
166173 err = trainerBucket .Put (util .BoltPathKey , []byte (index .CentroidIndexFileName ))
167174 if err != nil {
168175 trainReq .ackCh <- fmt .Errorf ("error updating centroid bucket: %v" , err )
169176 close (trainReq .ackCh )
170177 return
171178 }
172179
173- err = trainerBucket .Put (util .BoltTrainCompleteKey , trainReq .trainComplete )
180+ // train status - value is controlled by the application layer
181+ var comp byte
182+ if trainReq .trainComplete {
183+ comp = 1
184+ }
185+ err = trainerBucket .Put (util .BoltTrainCompleteKey , []byte {comp })
174186 if err != nil {
175187 trainReq .ackCh <- fmt .Errorf ("error updating train complete bucket: %v" , err )
176188 close (trainReq .ackCh )
177189 return
178190 }
179191
180- // track progress of training in temrs of samples processed
181- binary .LittleEndian .PutUint64 (buf , uint64 (totalSamplesProcessed ))
192+ totalSamplesProcessed += trainReq .vecCount
193+ // track progress of training in terms of samples processed
194+ binary .LittleEndian .PutUint64 (buf , totalSamplesProcessed )
182195 err = trainerBucket .Put (util .BoltTrainedSamplesKey , buf )
183196 if err != nil {
184197 trainReq .ackCh <- fmt .Errorf ("error updating trained samples bucket: %v" , err )
185198 close (trainReq .ackCh )
186199 return
187200 }
188201
202+ // training related internal data that needs to be stored as per
203+ // application layer via the SetInternal and GetInternal APIs
204+ err = trainerBucket .Put (util .BoltInternalKey , trainReq .internalData )
205+ if err != nil {
206+ trainReq .ackCh <- fmt .Errorf ("error updating train internal bucket: %v" , err )
207+ close (trainReq .ackCh )
208+ return
209+ }
189210 err = tx .Commit ()
190211 if err != nil {
191212 trainReq .ackCh <- fmt .Errorf ("error committing bolt transaction: %v" , err )
@@ -226,6 +247,15 @@ func (t *vectorTrainer) loadTrainedData(bucket *bolt.Bucket) error {
226247 if err != nil {
227248 return err
228249 }
250+
251+ internalData := bucket .Get (util .BoltInternalKey )
252+ trainedSamples := bucket .Get (util .BoltTrainedSamplesKey )
253+ trainComplete := bucket .Get (util .BoltTrainCompleteKey )
254+
255+ segmentSnapshot .cachedMeta .updateMeta ("internalData" , internalData )
256+ segmentSnapshot .cachedMeta .updateMeta ("trainComplete" , trainComplete )
257+ segmentSnapshot .cachedMeta .updateMeta ("trainedSamples" , binary .LittleEndian .Uint64 (trainedSamples ))
258+
229259 t .m .Lock ()
230260 defer t .m .Unlock ()
231261 t .centroidIndex = segmentSnapshot
@@ -250,38 +280,37 @@ func (t *vectorTrainer) train(batch *index.Batch) error {
250280 }
251281 }
252282
253- // just builds a new vector index out of the train data provided
254- // this is not necessarily the final train data since this is submitted
255- // as a request to the trainer component to be merged. once the training
256- // is complete, the template will be used for other operations down the line
257- // like merge and search.
258- //
259- // note: this might index text data too, how to handle this? s.segmentConfig?
260- // todo: updates/deletes -> data drift detection
261- seg , _ , err := t .parent .segPlugin .NewUsing (trainData , t .parent .segmentConfig )
283+ var seg segment.Segment
284+ var fin bool
285+ var err error
286+ trainComplete := batch .InternalOps [string (util .BoltTrainCompleteKey )]
287+ if trainComplete == nil {
288+ trainComplete = []byte ("false" )
289+ }
290+ fin , err = strconv .ParseBool (string (trainComplete ))
262291 if err != nil {
263292 return fmt .Errorf ("error parsing train complete: %v" , err )
264293 }
265294
266295 if ! fin {
267296 // just builds a new vector index out of the train data provided
268- // it'll be an IVF index so the centroids are computed at this stage and
269- // this template will be used in the indexing down the line to index
270- // the data vectors. s.segmentConfig will mark this as a training phase
271- // and zap will handle it accordingly .
297+ // this is not necessarily the final train data since this is submitted
298+ // as a request to the trainer component to be merged. once the training
299+ // is complete, the template will be used for other operations down the line
300+ // like merge and search .
272301 //
273302 // note: this might index text data too, how to handle this? s.segmentConfig?
274303 // todo: updates/deletes -> data drift detection
275- seg , _ , err = t .parent .segPlugin .NewEx (trainData , t .parent .segmentConfig )
304+ seg , _ , err = t .parent .segPlugin .NewUsing (trainData , t .parent .segmentConfig )
276305 if err != nil {
277306 return err
278307 }
279308 }
280309
281310 trainReq := & trainRequest {
311+ vecCount : uint64 (len (trainData )),
282312 sample : seg ,
283- vecCount : len (trainData ), // todo: multivector support
284- trainComplete : trainComplete ,
313+ trainComplete : fin ,
285314 ackCh : make (chan error ),
286315 }
287316
@@ -301,11 +330,18 @@ func (t *vectorTrainer) getInternal(key []byte) ([]byte, error) {
301330 switch string (key ) {
302331 case string (util .BoltTrainCompleteKey ):
303332 trainComplete := t .centroidIndex .cachedMeta .fetchMeta ("trainComplete" ).([]byte )
304- return trainComplete , nil
333+ if trainComplete [0 ] == 1 {
334+ return []byte ("true" ), nil
335+ }
336+ return []byte ("false" ), nil
305337 case string (util .BoltTrainedSamplesKey ):
306338 // keep rv in a human readable format
307339 trainedSamples := t .centroidIndex .cachedMeta .fetchMeta ("trainedSamples" ).(uint64 )
308340 return []byte (fmt .Sprintf ("%d" , trainedSamples )), nil
341+ // keep the default fetch to be from the internal data
342+ default :
343+ internalData := t .centroidIndex .cachedMeta .fetchMeta ("internalData" ).([]byte )
344+ return internalData , nil
309345 }
310346 }
311347 return nil , nil
0 commit comments