Skip to content

Commit dd61fe5

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents e7a9cd4 + 7ab0ff9 commit dd61fe5

13 files changed

Lines changed: 78 additions & 29 deletions

File tree

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
create:
8+
tags:
9+
- '**'
10+
pull_request:
11+
schedule:
12+
# Run every 1st of the month at 7:42am UTC.
13+
- cron: '42 7 1 * *'
14+
15+
jobs:
16+
find-benchmarks:
17+
runs-on: ubuntu-latest
18+
outputs:
19+
benchmark-dirs: ${{ steps.find-dirs.outputs.dirs }}
20+
steps:
21+
- uses: actions/checkout@v3
22+
- name: Find benchmark directories
23+
id: find-dirs
24+
run: |
25+
# Find all directories containing an objective.py file
26+
dirs=$(find . -maxdepth 2 -name "objective.py" -type f | xargs dirname | sed 's|./||' | jq -R -s -c 'split("\n")[:-1]')
27+
echo "dirs=$dirs" >> $GITHUB_OUTPUT
28+
29+
benchopt_dev:
30+
needs: find-benchmarks
31+
strategy:
32+
matrix:
33+
benchmark_dir: ${{ fromJson(needs.find-benchmarks.outputs.benchmark-dirs) }}
34+
uses: benchopt/template_benchmark/.github/workflows/test_benchmarks.yml@main
35+
with:
36+
benchopt_branch: benchopt@main
37+
benchmark_dir: ${{ matrix.benchmark_dir }}
38+
39+
lint:
40+
uses: benchopt/template_benchmark/.github/workflows/lint_benchmarks.yml@main

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
**/__pycache__/
33
**/__cache__/
44
**/data/
5+
**/outputs/
56

67
# IDE config folders
78
**/.vscode/

benchmark_template/objective.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from benchopt import BaseObjective
22

3-
import time
43
import deepinv as dinv
54
from torch.utils.data import DataLoader
65

@@ -16,7 +15,9 @@ class Objective(BaseObjective):
1615

1716
# Minimal version of benchopt required to run this benchmark.
1817
# Bump it up if the benchmark depends on a new feature of benchopt.
19-
min_benchopt_version = "1.7"
18+
min_benchopt_version = "1.8"
19+
20+
sampling_strategy = 'run_once'
2021

2122
def set_data(self, dataset, physics):
2223
self.dataset = dataset

benchmark_template/solvers/solver1.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@ class Solver(BaseSolver):
1010
# add any hyper-parameters here
1111
parameters = {}
1212

13-
sampling_strategy = 'run_once'
14-
15-
requirements = []
16-
1713
def set_objective(self, train_dataset=None, physics=None):
1814
device = (
1915
dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

cbsd500_gaussian_denoising/datasets/cbsd68.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,18 @@
1313
class Dataset(BaseDataset):
1414

1515
name = "CBSD68"
16-
1716
parameters = {
1817
'physics' : ['Denoising'],
1918
'noise' : ['GaussianNoise'],
20-
'img_size': [256],
2119
'sigma': [0.1],
20+
'img_size': [256],
2221
'debug': [False],
2322
}
2423

24+
test_parameters = {
25+
"debug": [True]
26+
}
27+
2528
requirements = ["datasets"]
2629

2730
def get_data(self):

cbsd500_gaussian_denoising/objective.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from benchopt import BaseObjective
22

3-
import time
43
import deepinv as dinv
54
from torch.utils.data import DataLoader
65

@@ -16,7 +15,10 @@ class Objective(BaseObjective):
1615

1716
# Minimal version of benchopt required to run this benchmark.
1817
# Bump it up if the benchmark depends on a new feature of benchopt.
19-
min_benchopt_version = "1.7"
18+
min_benchopt_version = "1.8"
19+
20+
# Deactivate multiple runs for each solver
21+
sampling_strategy = "run_once"
2022

2123
def set_data(self, dataset, physics):
2224
self.dataset = dataset
@@ -44,7 +46,14 @@ def evaluate_result(self, model):
4446
return results
4547

4648
def get_one_result(self):
47-
return dict(model=lambda x: x)
49+
50+
class DummyModel:
51+
def eval(self): pass
52+
53+
def __call__(self, x, physics=None):
54+
return physics.A_adjoint(x)
55+
56+
return dict(model=DummyModel())
4857

4958
def get_objective(self):
5059
return dict(

cbsd500_gaussian_denoising/solvers/drunet.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@ class Solver(BaseSolver):
88

99
parameters = {}
1010

11-
sampling_strategy = 'run_once'
12-
13-
requirements = []
14-
1511
def set_objective(self, train_dataset=None, physics=None):
1612
self.model = dinv.models.ArtifactRemoval(
1713
dinv.models.DRUNet()

cbsd500_gaussian_denoising/solvers/restformer.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@ class Solver(BaseSolver):
88

99
parameters = {}
1010

11-
sampling_strategy = 'run_once'
12-
13-
requirements = []
14-
1511
def set_objective(self, train_dataset=None, physics=None):
1612
self.model = dinv.models.ArtifactRemoval(
1713
dinv.models.Restormer()

div2k_gaussian_deblurring/datasets/div2k.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class Dataset(BaseDataset):
2121
'debug': [False],
2222
}
2323

24+
test_parameters = {
25+
"debug": [True]
26+
}
27+
2428
requirements = ["datasets"]
2529

2630
def get_data(self):

div2k_gaussian_deblurring/objective.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ class Objective(BaseObjective):
1515

1616
# Minimal version of benchopt required to run this benchmark.
1717
# Bump it up if the benchmark depends on a new feature of benchopt.
18-
min_benchopt_version = "1.7"
18+
min_benchopt_version = "1.8"
19+
20+
# Deactivate multiple runs for each solver
21+
sampling_strategy = "run_once"
1922

2023
def set_data(self, dataset, physics):
2124
self.dataset = dataset
@@ -44,7 +47,14 @@ def evaluate_result(self, model):
4447
return results
4548

4649
def get_one_result(self):
47-
return dict(model=lambda x: x)
50+
51+
class DummyModel:
52+
def eval(self): pass
53+
54+
def __call__(self, x, physics=None):
55+
return physics.A_adjoint(x)
56+
57+
return dict(model=DummyModel())
4858

4959
def get_objective(self):
5060
return dict(

0 commit comments

Comments
 (0)