Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions RATapi/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
Field,
ValidationError,
ValidatorFunctionWrapHandler,
field_validator,
model_serializer,
model_validator,
)

from RATapi.utils.custom_errors import custom_pydantic_validation_error
from RATapi.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies

common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleParams", "display"]
common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleMinAngle", "resampleNPoints", "display"]
update_fields = ["updateFreq", "updatePlotFreq"]
fields = {
"calculate": common_fields,
Expand All @@ -41,7 +40,8 @@ class Controls(BaseModel, validate_assignment=True, extra="forbid"):
procedure: Procedures = Procedures.Calculate
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2)
resampleMinAngle: float = Field(0.9, le=1, gt=0)
resampleNPoints: int = Field(50, gt=0)
display: Display = Display.Iter
# Simplex
xTolerance: float = Field(1.0e-6, gt=0.0)
Expand Down Expand Up @@ -117,16 +117,6 @@ def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandle

return validated_self

@field_validator("resampleParams")
@classmethod
def check_resample_params(cls, values: list[float]) -> list[float]:
"""Make sure each of the two values of resampleParams satisfy their conditions."""
if not 0 < values[0] < 1:
raise ValueError("resampleParams[0] must be between 0 and 1")
if values[1] < 0:
raise ValueError("resampleParams[1] must be greater than or equal to 0")
return values

@model_serializer
def serialize(self):
"""Filter fields so only those applying to the chosen procedure are serialized."""
Expand Down
2 changes: 1 addition & 1 deletion RATapi/examples/absorption/absorption.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def absorption():
)

# Now make a controls block and run the code
controls = RAT.Controls(parallel="contrasts", resampleParams=[0.9, 150.0])
controls = RAT.Controls(parallel="contrasts", resampleNPoints=150)
problem, results = RAT.run(problem, controls)

return problem, results
Expand Down
3 changes: 2 additions & 1 deletion RATapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ def make_controls(input_controls: RATapi.Controls, checks: Checks) -> Control:
controls.procedure = input_controls.procedure
controls.parallel = input_controls.parallel
controls.calcSldDuringFit = input_controls.calcSldDuringFit
controls.resampleParams = input_controls.resampleParams
controls.resampleMinAngle = input_controls.resampleMinAngle
controls.resampleNPoints = input_controls.resampleNPoints
controls.display = input_controls.display
# Simplex
controls.xTolerance = input_controls.xTolerance
Expand Down
57 changes: 30 additions & 27 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,8 @@ struct Control {
real_T propScale {};
real_T nsTolerance {};
boolean_T calcSldDuringFit {};
py::array_t<real_T> resampleParams;
real_T resampleMinAngle {};
real_T resampleNPoints {};
real_T updateFreq {};
real_T updatePlotFreq {};
real_T nSamples {};
Expand Down Expand Up @@ -914,8 +915,8 @@ RAT::struct2_T createStruct2T(const Control& control)
stringToRatArray(control.procedure, control_struct.procedure.data, control_struct.procedure.size);
stringToRatArray(control.display, control_struct.display.data, control_struct.display.size);
control_struct.xTolerance = control.xTolerance;
control_struct.resampleParams[0] = control.resampleParams.at(0);
control_struct.resampleParams[1] = control.resampleParams.at(1);
control_struct.resampleMinAngle = control.resampleMinAngle;
control_struct.resampleNPoints = control.resampleNPoints;
stringToRatArray(control.boundHandling, control_struct.boundHandling.data, control_struct.boundHandling.size);
control_struct.adaptPCR = control.adaptPCR;
control_struct.checks = createStruct3(control.checks);
Expand Down Expand Up @@ -1616,7 +1617,8 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("propScale", &Control::propScale)
.def_readwrite("nsTolerance", &Control::nsTolerance)
.def_readwrite("calcSldDuringFit", &Control::calcSldDuringFit)
.def_readwrite("resampleParams", &Control::resampleParams)
.def_readwrite("resampleMinAngle", &Control::resampleMinAngle)
.def_readwrite("resampleNPoints", &Control::resampleNPoints)
.def_readwrite("updateFreq", &Control::updateFreq)
.def_readwrite("updatePlotFreq", &Control::updatePlotFreq)
.def_readwrite("nSamples", &Control::nSamples)
Expand All @@ -1633,14 +1635,14 @@ PYBIND11_MODULE(rat_core, m) {
return py::make_tuple(ctrl.parallel, ctrl.procedure, ctrl.display, ctrl.xTolerance, ctrl.funcTolerance,
ctrl.maxFuncEvals, ctrl.maxIterations, ctrl.populationSize, ctrl.fWeight, ctrl.crossoverProbability,
ctrl.targetValue, ctrl.numGenerations, ctrl.strategy, ctrl.nLive, ctrl.nMCMC, ctrl.propScale,
ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleParams, ctrl.updateFreq, ctrl.updatePlotFreq,
ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, ctrl.boundHandling, ctrl.adaptPCR,
ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam, ctrl.checks.fitQzshift,
ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut,
ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleMinAngle, ctrl.resampleNPoints,
ctrl.updateFreq, ctrl.updatePlotFreq, ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma,
ctrl.boundHandling, ctrl.adaptPCR, ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam,
ctrl.checks.fitQzshift, ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut,
ctrl.checks.fitResolutionParam, ctrl.checks.fitDomainRatio);
},
[](py::tuple t) { // __setstate__
if (t.size() != 36)
if (t.size() != 37)
throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!");

/* Create a new C++ instance */
Expand All @@ -1664,25 +1666,26 @@ PYBIND11_MODULE(rat_core, m) {
ctrl.propScale = t[15].cast<real_T>();
ctrl.nsTolerance = t[16].cast<real_T>();
ctrl.calcSldDuringFit = t[17].cast<boolean_T>();
ctrl.resampleParams = t[18].cast<py::array_t<real_T>>();
ctrl.updateFreq = t[19].cast<real_T>();
ctrl.updatePlotFreq = t[20].cast<real_T>();
ctrl.nSamples = t[21].cast<real_T>();
ctrl.nChains = t[22].cast<real_T>();
ctrl.jumpProbability = t[23].cast<real_T>();
ctrl.pUnitGamma = t[24].cast<real_T>();
ctrl.boundHandling = t[25].cast<std::string>();
ctrl.adaptPCR = t[26].cast<boolean_T>();
ctrl.IPCFilePath = t[27].cast<std::string>();
ctrl.resampleMinAngle = t[18].cast<real_T>();
ctrl.resampleNPoints = t[19].cast<real_T>();
ctrl.updateFreq = t[20].cast<real_T>();
ctrl.updatePlotFreq = t[21].cast<real_T>();
ctrl.nSamples = t[22].cast<real_T>();
ctrl.nChains = t[23].cast<real_T>();
ctrl.jumpProbability = t[24].cast<real_T>();
ctrl.pUnitGamma = t[25].cast<real_T>();
ctrl.boundHandling = t[26].cast<std::string>();
ctrl.adaptPCR = t[27].cast<boolean_T>();
ctrl.IPCFilePath = t[28].cast<std::string>();

ctrl.checks.fitParam = t[28].cast<py::array_t<real_T>>();
ctrl.checks.fitBackgroundParam = t[29].cast<py::array_t<real_T>>();
ctrl.checks.fitQzshift = t[30].cast<py::array_t<real_T>>();
ctrl.checks.fitScalefactor = t[31].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkIn = t[32].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkOut = t[33].cast<py::array_t<real_T>>();
ctrl.checks.fitResolutionParam = t[34].cast<py::array_t<real_T>>();
ctrl.checks.fitDomainRatio = t[35].cast<py::array_t<real_T>>();
ctrl.checks.fitParam = t[29].cast<py::array_t<real_T>>();
ctrl.checks.fitBackgroundParam = t[30].cast<py::array_t<real_T>>();
ctrl.checks.fitQzshift = t[31].cast<py::array_t<real_T>>();
ctrl.checks.fitScalefactor = t[32].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkIn = t[33].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkOut = t[34].cast<py::array_t<real_T>>();
ctrl.checks.fitResolutionParam = t[35].cast<py::array_t<real_T>>();
ctrl.checks.fitDomainRatio = t[36].cast<py::array_t<real_T>>();

return ctrl;
}));
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from setuptools.command.build_clib import build_clib
from setuptools.command.build_ext import build_ext

__version__ = "0.0.0.dev1"
__version__ = "0.0.0.dev2"
PACKAGE_NAME = "RATapi"

with open("README.md") as f:
Expand Down
Loading