Commit 76381c3

Adding discrete control functionality to libemg (beta) (#129)
* Got data collection working
* Added MVLDA and DTW
* Discrete working
* Got the online classifier working
* Updated based on Copilot's suggestions
* Added torch
* Added tslearn
* Added docs for discrete stuff
* Fixed documentation
1 parent de37630 commit 76381c3

16 files changed

Lines changed: 1073 additions & 67 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -70,3 +70,4 @@ MLPR.py
 docs/Makefile
 sifibridge-*
 *.pyc
+*.model
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
# Discrete Classifiers

Unlike continuous classifiers that output a prediction for every window of EMG data, discrete classifiers are designed for recognizing transient, isolated gestures. These classifiers operate on variable-length templates (sequences of windows) and are well-suited for detecting distinct movements like finger snaps, taps, or quick hand gestures.

Discrete classifiers expect input data in a different format than continuous classifiers:

- **Continuous classifiers**: Operate on individual windows of shape `(n_windows, n_features)`.
- **Discrete classifiers**: Operate on templates (sequences of windows) where each template has shape `(n_frames, n_features)` and can vary in length.

To prepare data for discrete classifiers, use the `discrete=True` parameter when calling `parse_windows()` on your `OfflineDataHandler`:

```Python
from libemg.data_handler import OfflineDataHandler

odh = OfflineDataHandler()
odh.get_data('./data/', regex_filters)
windows, metadata = odh.parse_windows(window_size=50, window_increment=10, discrete=True)
# windows is now a list of templates, one per file/rep
```
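As a toy illustration of the resulting structure (synthetic arrays, not real EMG), each template is a 2D array, and different reps may contain different numbers of frames:

```python
import numpy as np

# Stand-in for the discrete parse_windows output: three templates
# (one per rep) with differing frame counts but a shared feature width
rng = np.random.default_rng(0)
templates = [rng.normal(size=(n_frames, 4)) for n_frames in (12, 9, 15)]

print([t.shape for t in templates])  # [(12, 4), (9, 4), (15, 4)]
```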
For feature extraction with discrete data, use the `discrete=True` parameter:

```Python
from libemg.feature_extractor import FeatureExtractor

fe = FeatureExtractor()
features = fe.extract_features(['MAV', 'ZC', 'SSC', 'WL'], windows, discrete=True, array=True)
# features is a list of arrays, one per template
```
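For reference, three of the four time-domain features listed above have simple standard definitions (SSC additionally counts slope sign changes, and thresholded variants exist); this sketch uses those textbook formulas and is independent of LibEMG's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
window = rng.normal(size=(50, 8))  # 50 samples, 8 channels

mav = np.mean(np.abs(window), axis=0)                 # Mean Absolute Value
wl = np.sum(np.abs(np.diff(window, axis=0)), axis=0)  # Waveform Length
signs = np.signbit(window).astype(int)
zc = np.sum(np.abs(np.diff(signs, axis=0)), axis=0)   # Zero Crossings

print(mav.shape, wl.shape, zc.shape)  # (8,) (8,) (8,)
```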
## Majority Vote LDA (MVLDA)

A classifier that applies Linear Discriminant Analysis (LDA) to each frame within a template and uses majority voting to determine the final prediction. This approach is simple yet effective for discrete gesture recognition.

```Python
from libemg._discrete_models import MVLDA

model = MVLDA()
model.fit(train_features, train_labels)
predictions = model.predict(test_features)
probabilities = model.predict_proba(test_features)
```
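The frame-level-LDA-plus-majority-vote idea can be sketched on synthetic data with scikit-learn and SciPy directly (toy numbers; this mirrors the pattern `MVLDA` wraps):

```python
import numpy as np
from scipy import stats
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(0)

# Two toy gesture classes, five variable-length templates each
templates, labels = [], []
for label, center in [(0, -2.0), (1, 2.0)]:
    for _ in range(5):
        n_frames = int(rng.integers(8, 15))
        templates.append(rng.normal(center, 0.5, size=(n_frames, 4)))
        labels.append(label)

# Fit LDA on individual frames; every frame inherits its template's label
frame_labels = np.hstack([[l] * t.shape[0] for t, l in zip(templates, labels)])
lda = LinearDiscriminantAnalysis().fit(np.vstack(templates), frame_labels)

# Classify a new template by majority vote over its frame predictions
test_template = rng.normal(2.0, 0.5, size=(10, 4))
vote = stats.mode(lda.predict(test_template), keepdims=False)[0]
print(vote)  # 1
```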
## Dynamic Time Warping Classifier (DTWClassifier)

A template-matching classifier that uses Dynamic Time Warping (DTW) distance to compare test samples against stored training templates. DTW is particularly useful when gestures may vary in speed or duration, as it can align sequences with different temporal characteristics.

```Python
from libemg._discrete_models import DTWClassifier

model = DTWClassifier(n_neighbors=3)
model.fit(train_features, train_labels)
predictions = model.predict(test_features)
probabilities = model.predict_proba(test_features)
```
The `n_neighbors` parameter controls how many nearest templates are used for voting (k-nearest neighbors with DTW distance).
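For intuition, the DTW recurrence behind this alignment fits in a few lines of plain NumPy (a didactic sketch; `DTWClassifier` itself delegates the distance computation to `tslearn`):

```python
import numpy as np

def dtw_distance(a, b):
    """O(len(a) * len(b)) DTW between two (n_frames, n_features) arrays."""
    n, m = len(a), len(b)
    D = np.full((n + 1, m + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = np.linalg.norm(a[i - 1] - b[j - 1])  # local frame distance
            D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
    return D[n, m]

# A slow and a fast version of the same ramp align well despite the
# length mismatch; a reversed ramp does not
slow = np.linspace(0, 1, 20)[:, None]
fast = np.linspace(0, 1, 10)[:, None]
print(dtw_distance(slow, fast) < dtw_distance(slow, fast[::-1]))  # True
```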
## Pretrained Myo Cross-User Model (MyoCrossUserPretrained)

A pretrained deep learning model for cross-user discrete gesture recognition using the Myo armband. This model uses a convolutional-recurrent architecture and recognizes 6 gestures: Nothing, Close, Flexion, Extension, Open, and Pinch.

```Python
from libemg._discrete_models import MyoCrossUserPretrained

model = MyoCrossUserPretrained()
# Model is automatically downloaded on first use

# The model provides recommended parameters for OnlineDiscreteClassifier
print(model.args)
# {'window_size': 10, 'window_increment': 5, 'null_label': 0, ...}

predictions = model.predict(test_data)
probabilities = model.predict_proba(test_data)
```

This model expects raw windowed EMG data (not extracted features) with shape `(batch_size, seq_len, n_channels, n_samples)`.
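As a shape sanity check, a dummy batch in this layout might look as follows (the 8 channels are an assumption based on the Myo armband; the 10-sample windows follow the `window_size` above):

```python
import numpy as np

# Hypothetical batch: 2 templates, 30 windows each, 8 channels, 10 samples/window
dummy_batch = np.zeros((2, 30, 8, 10), dtype=np.float32)
print(dummy_batch.shape)  # (2, 30, 8, 10)
```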
## Online Discrete Classification

For real-time discrete gesture recognition, use the `OnlineDiscreteClassifier`:

```Python
from libemg.emg_predictor import OnlineDiscreteClassifier
from libemg._discrete_models import MyoCrossUserPretrained

# Load pretrained model
model = MyoCrossUserPretrained()

# Create online classifier
classifier = OnlineDiscreteClassifier(
    odh=online_data_handler,
    model=model,
    window_size=model.args['window_size'],
    window_increment=model.args['window_increment'],
    null_label=model.args['null_label'],
    feature_list=model.args['feature_list'],  # None for raw data
    template_size=model.args['template_size'],
    min_template_size=model.args['min_template_size'],
    gesture_mapping=model.args['gesture_mapping'],
    buffer_size=model.args['buffer_size'],
    rejection_threshold=0.5,
    debug=True
)

# Start recognition loop
classifier.run()
```
## Creating Custom Discrete Classifiers

Any custom discrete classifier should implement the following methods to work with LibEMG:

- `fit(x, y)`: Train the model where `x` is a list of templates and `y` is the corresponding labels.
- `predict(x)`: Return predicted class labels for a list of templates.
- `predict_proba(x)`: Return predicted class probabilities for a list of templates.

```Python
import numpy as np

class CustomDiscreteClassifier:
    def __init__(self):
        self.classes_ = None

    def fit(self, x, y):
        # x: list of templates (each template is an array of frames)
        # y: labels for each template
        self.classes_ = np.unique(y)
        # ... training logic

    def predict(self, x):
        # Return array of predictions
        pass

    def predict_proba(self, x):
        # Return array of shape (n_samples, n_classes)
        pass
```
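To make the interface concrete, here is a complete, deliberately simplistic classifier that satisfies it: each template is averaged over time, then scored by distance to per-class mean vectors. The class name and pooling strategy are illustrative only, not part of LibEMG:

```python
import numpy as np

class MeanTemplateClassifier:
    """Toy discrete classifier: time-average each template, then score by
    distance to per-class mean vectors (softmax-like weighting)."""

    def __init__(self):
        self.classes_ = None
        self.means_ = None

    def fit(self, x, y):
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        pooled = np.vstack([t.mean(axis=0) for t in x])  # (n_templates, n_features)
        self.means_ = np.vstack([pooled[y == c].mean(axis=0) for c in self.classes_])

    def predict_proba(self, x):
        pooled = np.vstack([t.mean(axis=0) for t in x])
        dists = np.linalg.norm(pooled[:, None, :] - self.means_[None, :, :], axis=2)
        weights = np.exp(-dists)  # closer class mean -> larger weight
        return weights / weights.sum(axis=1, keepdims=True)

    def predict(self, x):
        return self.classes_[np.argmax(self.predict_proba(x), axis=1)]

# Quick check on synthetic templates
rng = np.random.default_rng(1)
train = [rng.normal(c, 0.3, size=(int(rng.integers(5, 10)), 3))
         for c in (0.0, 0.0, 5.0, 5.0)]
clf = MeanTemplateClassifier()
clf.fit(train, [0, 0, 1, 1])
print(clf.predict([rng.normal(5.0, 0.3, size=(7, 3))]))  # [1]
```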

docs/source/documentation/prediction/prediction.rst

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ EMG Prediction
 .. include:: classification_doc.md
    :parser: myst_parser.sphinx_

+.. include:: discrete_classification_doc.md
+   :parser: myst_parser.sphinx_
+
 .. include:: regression_doc.md
    :parser: myst_parser.sphinx_

docs/source/documentation/prediction/predictors.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 After recording, processing, and extracting features from a window of EMG data, it is passed to a machine learning algorithm for prediction. These control systems have evolved in the prosthetics community for continuously predicting muscular contractions for enabling prosthesis control. Therefore, they are primarily limited to recognizing static contractions (e.g., hand open/close and wrist flexion/extension) as they have no temporal awareness. Currently, this is the form of recognition supported by LibEMG and is an initial step to explore EMG as an interaction opportunity for general-purpose use. This section highlights the machine-learning strategies that are part of `LibEMG`'s pipeline.

-There are two types of models supported in `LibEMG`: classifiers and regressors. Classifiers output a discrete motion class for each window, whereas regressors output a continuous prediction along a degree of freedom. For both classifiers and regressors, `LibEMG` supports statistical models as well as deep learning models. Additionally, a number of post-processing methods (i.e., techniques to improve performance after prediction) are supported for all models.
+There are three types of models supported in `LibEMG`: classifiers, regressors, and discrete classifiers. Classifiers output a motion class for each window of EMG data, whereas regressors output a continuous prediction along a degree of freedom. Discrete classifiers are designed for recognizing transient, isolated gestures and operate on variable-length templates rather than individual windows. For classifiers and regressors, `LibEMG` supports statistical models as well as deep learning models. Additionally, a number of post-processing methods (i.e., techniques to improve performance after prediction) are supported for all models.

 ## Statistical Models

libemg/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,3 +11,4 @@
 from libemg import gui
 from libemg import shared_memory_manager
 from libemg import environments
+from libemg import _discrete_models

libemg/_discrete_models/DTW.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
from tslearn.metrics import dtw_path
import numpy as np


class DTWClassifier:
    """Dynamic Time Warping k-Nearest Neighbors classifier.

    A classifier that uses Dynamic Time Warping (DTW) distance for template
    matching with k-nearest neighbors. Suitable for discrete gesture recognition
    where temporal alignment between samples varies.

    Parameters
    ----------
    n_neighbors: int, default=1
        Number of neighbors to use for k-nearest neighbors voting.

    Attributes
    ----------
    templates: list of ndarray
        The training templates stored after fitting.
    labels: ndarray
        The labels corresponding to each template.
    classes_: ndarray
        The unique class labels known to the classifier.
    """

    def __init__(self, n_neighbors=1):
        """Initialize the DTW classifier.

        Parameters
        ----------
        n_neighbors: int, default=1
            Number of neighbors to use for k-nearest neighbors voting.
        """
        self.n_neighbors = n_neighbors
        self.templates = None
        self.labels = None
        self.classes_ = None

    def fit(self, features, labels):
        """Fit the DTW classifier by storing training templates.

        Parameters
        ----------
        features: list of ndarray
            A list of training samples (templates) where each sample is
            a 2D array of shape (n_frames, n_features).
        labels: array-like
            The target labels for each template.
        """
        self.templates = features
        self.labels = np.array(labels)
        self.classes_ = np.unique(labels)

    def predict(self, samples):
        """Predict class labels for samples.

        Parameters
        ----------
        samples: list of ndarray
            A list of samples to classify where each sample is a 2D array
            of shape (n_frames, n_features).

        Returns
        -------
        ndarray
            Predicted class labels for each sample.
        """
        # Reuse predict_proba and take the class with the highest probability
        probas = self.predict_proba(samples)
        return self.classes_[np.argmax(probas, axis=1)]

    def predict_proba(self, samples, gamma=None, eps=1e-12):
        """Predict class probabilities using DTW distance-weighted voting.

        Computes DTW distances to all templates, selects k-nearest neighbors,
        and computes class probabilities using exponentially weighted voting.

        Parameters
        ----------
        samples: list of ndarray
            A list of samples to classify where each sample is a 2D array
            of shape (n_frames, n_features).
        gamma: float, default=None
            The kernel bandwidth for distance weighting. If None, automatically
            computed based on median neighbor distance.
        eps: float, default=1e-12
            Small constant to prevent division by zero.

        Returns
        -------
        ndarray
            Predicted class probabilities of shape (n_samples, n_classes).
        """
        if self.templates is None:
            raise ValueError("Call fit() before predict_proba().")

        # Iterate over the raw list so equal-length templates are not
        # silently stacked into a single object-dtype ndarray
        out = np.zeros((len(samples), len(self.classes_)), dtype=float)

        for i, s in enumerate(samples):
            # DTW distances to templates (dtw_path returns (path, distance))
            dists = np.array([dtw_path(t, s)[1] for t in self.templates], dtype=float)

            # kNN
            nn_idx = np.argsort(dists)[:self.n_neighbors]
            nn_dists = dists[nn_idx]
            nn_labels = self.labels[nn_idx]

            # choose gamma if not provided (scale to typical distance)
            g = gamma
            if g is None:
                scale = np.median(nn_dists) if len(nn_dists) else 1.0
                g = 1.0 / max(scale, eps)

            weights = np.exp(-g * nn_dists)  # closer -> bigger weight

            # accumulate per class
            for cls_j, cls in enumerate(self.classes_):
                out[i, cls_j] = weights[nn_labels == cls].sum()

            # normalize to probabilities
            z = out[i].sum()
            out[i] = out[i] / max(z, eps)

        return out
libemg/_discrete_models/MVLDA.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import numpy as np
from scipy import stats


class MVLDA:
    """Majority Vote Linear Discriminant Analysis classifier.

    A classifier that uses Linear Discriminant Analysis (LDA) on individual frames
    and aggregates predictions using majority voting. This is designed for discrete
    gesture recognition where each sample contains multiple frames.

    Attributes
    ----------
    model: LinearDiscriminantAnalysis
        The underlying LDA model.
    classes_: ndarray
        The class labels known to the classifier.
    """

    def __init__(self):
        """Initialize the MVLDA classifier."""
        self.model = None
        self.classes_ = None

    def fit(self, x, y):
        """Fit the MVLDA classifier on training data.

        Parameters
        ----------
        x: list of ndarray
            A list of samples where each sample is a 2D array of shape
            (n_frames, n_features). Each sample can have a different number of frames.
        y: array-like
            The target labels for each sample.
        """
        self.model = LinearDiscriminantAnalysis()
        # Create a flat array of labels corresponding to every frame in x
        labels = np.hstack([[v] * x[i].shape[0] for i, v in enumerate(y)])
        self.model.fit(np.vstack(x), labels)
        # Store classes for consistent probability mapping
        self.classes_ = self.model.classes_

    def predict(self, x):
        """Predict class labels using majority voting.

        Performs frame-level LDA predictions and returns the majority vote
        for each sample.

        Parameters
        ----------
        x: list of ndarray
            A list of samples where each sample is a 2D array of shape
            (n_frames, n_features).

        Returns
        -------
        ndarray
            Predicted class labels for each sample.
        """
        preds = []
        for s in x:
            frame_predictions = self.model.predict(s)
            # Majority vote on the frame-level labels
            majority_vote = stats.mode(frame_predictions, keepdims=False)[0]
            preds.append(majority_vote)
        return np.array(preds)

    def predict_proba(self, x):
        """Predict class probabilities using soft voting.

        Calculates probabilities by averaging the frame-level probabilities
        for each sample (soft voting).

        Parameters
        ----------
        x: list of ndarray
            A list of samples where each sample is a 2D array of shape
            (n_frames, n_features).

        Returns
        -------
        ndarray
            Predicted class probabilities of shape (n_samples, n_classes).
        """
        probas = []
        for s in x:
            # Get probabilities for each frame: shape (n_frames, n_classes)
            frame_probas = self.model.predict_proba(s)

            # Average probabilities across all frames in this sample
            sample_proba = np.mean(frame_probas, axis=0)
            probas.append(sample_proba)

        return np.array(probas)
