Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ dropSubscription
// ---- Create Model
createModel
: CREATE MODEL modelName=identifier uriClause
| CREATE MODEL modelType=identifier modelId=identifier (WITH HYPERPARAMETERS LR_BRACKET hparamPair (COMMA hparamPair)* RR_BRACKET)? (FROM MODEL existingModelId=identifier)? ON DATASET LR_BRACKET trainingData RR_BRACKET
| CREATE MODEL modelId=identifier (WITH HYPERPARAMETERS LR_BRACKET hparamPair (COMMA hparamPair)* RR_BRACKET)? FROM MODEL existingModelId=identifier ON DATASET LR_BRACKET trainingData RR_BRACKET
;

trainingData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2660,7 +2660,6 @@ public TSStatus createTraining(TCreateTrainingReq req) {

TTrainingReq trainingReq = new TTrainingReq();
trainingReq.setModelId(req.getModelId());
trainingReq.setModelType(req.getModelType());
if (req.isSetExistingModelId()) {
trainingReq.setExistingModelId(req.getExistingModelId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1359,11 +1359,7 @@ protected IConfigTask visitCreateTraining(CreateTraining node, MPPQueryContext c
context.setQueryType(QueryType.WRITE);

return new CreateTrainingTask(
node.getModelId(),
node.getModelType(),
node.getParameters(),
node.getExistingModelId(),
node.getTargetSql());
node.getModelId(), node.getParameters(), node.getExistingModelId(), node.getTargetSql());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,6 @@ public IConfigTask visitCreateTraining(
}
return new CreateTrainingTask(
createTrainingStatement.getModelId(),
createTrainingStatement.getModelType(),
createTrainingStatement.getParameters(),
createTrainingStatement.getTargetTimeRanges(),
createTrainingStatement.getExistingModelId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3331,7 +3331,6 @@ public SettableFuture<ConfigTaskResult> showModels(final String modelName) {
@Override
public SettableFuture<ConfigTaskResult> createTraining(
String modelId,
String modelType,
boolean isTableModel,
Map<String, String> parameters,
List<List<Long>> timeRanges,
Expand All @@ -3341,7 +3340,7 @@ public SettableFuture<ConfigTaskResult> createTraining(
final SettableFuture<ConfigTaskResult> future = SettableFuture.create();
try (final ConfigNodeClient client =
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
final TCreateTrainingReq req = new TCreateTrainingReq(modelId, modelType, isTableModel);
final TCreateTrainingReq req = new TCreateTrainingReq(modelId, isTableModel, existingModelId);

if (isTableModel) {
TDataSchemaForTable dataSchemaForTable = new TDataSchemaForTable();
Expand All @@ -3354,7 +3353,6 @@ public SettableFuture<ConfigTaskResult> createTraining(
}
req.setParameters(parameters);
req.setTimeRanges(timeRanges);
req.setExistingModelId(existingModelId);
final TSStatus executionStatus = client.createTraining(req);
if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) {
future.setException(new IoTDBException(executionStatus.message, executionStatus.code));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ SettableFuture<ConfigTaskResult> createModel(

SettableFuture<ConfigTaskResult> createTraining(
String modelId,
String modelType,
boolean isTableModel,
Map<String, String> parameters,
List<List<Long>> timeRanges,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
public class CreateTrainingTask implements IConfigTask {

private final String modelId;
private final String modelType;
private final boolean isTableModel;
private final Map<String, String> parameters;

Expand All @@ -45,13 +44,8 @@ public class CreateTrainingTask implements IConfigTask {

// For table model
public CreateTrainingTask(
String modelId,
String modelType,
Map<String, String> parameters,
String existingModelId,
String targetSql) {
String modelId, Map<String, String> parameters, String existingModelId, String targetSql) {
this.modelId = modelId;
this.modelType = modelType;
this.parameters = parameters;
this.existingModelId = existingModelId;
this.targetSql = targetSql;
Expand All @@ -61,13 +55,11 @@ public CreateTrainingTask(
// For tree model
public CreateTrainingTask(
String modelId,
String modelType,
Map<String, String> parameters,
List<List<Long>> timeRanges,
String existingModelId,
List<String> targetPaths) {
this.modelId = modelId;
this.modelType = modelType;
this.parameters = parameters;
this.timeRanges = timeRanges;
this.existingModelId = existingModelId;
Expand All @@ -80,13 +72,6 @@ public CreateTrainingTask(
public ListenableFuture<ConfigTaskResult> execute(IConfigTaskExecutor configTaskExecutor)
throws InterruptedException {
return configTaskExecutor.createTraining(
modelId,
modelType,
isTableModel,
parameters,
timeRanges,
existingModelId,
targetSql,
targetPaths);
modelId, isTableModel, parameters, timeRanges, existingModelId, targetSql, targetPaths);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1368,9 +1368,7 @@ public static void validateModelName(String modelName) {
public Statement visitCreateModel(IoTDBSqlParser.CreateModelContext ctx) {
if (ctx.modelName == null) {
String modelId = ctx.modelId.getText();
String modelType = ctx.modelType.getText();
CreateTrainingStatement createTrainingStatement =
new CreateTrainingStatement(modelId, modelType);
CreateTrainingStatement createTrainingStatement = new CreateTrainingStatement(modelId);
if (ctx.hparamPair() != null) {
Map<String, String> parameterList = new HashMap<>();
for (IoTDBSqlParser.HparamPairContext hparamPairContext : ctx.hparamPair()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,14 @@
public class CreateTraining extends Statement {

private final String modelId;
private final String modelType;
private final String targetSql;

private Map<String, String> parameters;
private String existingModelId = null;

public CreateTraining(String modelId, String modelType, String targetSql) {
public CreateTraining(String modelId, String targetSql) {
super(null);
this.modelId = modelId;
this.modelType = modelType;
this.targetSql = targetSql;
}

Expand All @@ -56,10 +54,6 @@ public String getModelId() {
return modelId;
}

public String getModelType() {
return modelType;
}

public Map<String, String> getParameters() {
return parameters;
}
Expand All @@ -79,7 +73,7 @@ public List<? extends Node> getChildren() {

@Override
public int hashCode() {
return Objects.hash(modelId, modelType, targetSql, existingModelId, parameters);
return Objects.hash(modelId, targetSql, existingModelId, parameters);
}

@Override
Expand All @@ -89,7 +83,6 @@ public boolean equals(Object obj) {
}
CreateTraining createTraining = (CreateTraining) obj;
return modelId.equals(createTraining.modelId)
&& modelType.equals(createTraining.modelType)
&& Objects.equals(existingModelId, createTraining.existingModelId)
&& Objects.equals(parameters, createTraining.parameters)
&& Objects.equals(targetSql, createTraining.targetSql);
Expand All @@ -101,9 +94,6 @@ public String toString() {
+ "modelId='"
+ modelId
+ '\''
+ ", modelType='"
+ modelType
+ '\''
+ ", parameters="
+ parameters
+ ", existingModelId='"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3558,13 +3558,12 @@ public static void validateModelName(String modelName) {
public Node visitCreateModelStatement(RelationalSqlParser.CreateModelStatementContext ctx) {
String modelId = ctx.modelId.getText();
validateModelName(modelId);
String modelType = ctx.modelType.getText();

if (ctx.targetData == null) {
throw new SemanticException("Target data in sql should be set in CREATE MODEL");
}
String targetData = ((StringLiteral) visit(ctx.targetData)).getValue();
CreateTraining createTraining = new CreateTraining(modelId, modelType, targetData);
CreateTraining createTraining = new CreateTraining(modelId, targetData);
if (ctx.HYPERPARAMETERS() != null) {
Map<String, String> parameters = new HashMap<>();
for (RelationalSqlParser.HparamPairContext hparamPairContext : ctx.hparamPair()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,15 @@
public class CreateTrainingStatement extends Statement implements IConfigStatement {

private final String modelId;
private final String modelType;

private Map<String, String> parameters;
private String existingModelId = null;

private List<PartialPath> targetPathPatterns;
private List<List<Long>> targetTimeRanges;

public CreateTrainingStatement(String modelId, String modelType) {
public CreateTrainingStatement(String modelId) {
this.modelId = modelId;
this.modelType = modelType;
}

public void setTargetPathPatterns(List<PartialPath> targetPathPatterns) {
Expand All @@ -65,10 +63,6 @@ public String getModelId() {
return modelId;
}

public String getModelType() {
return modelType;
}

public void setExistingModelId(String existingModelId) {
this.existingModelId = existingModelId;
}
Expand All @@ -87,7 +81,7 @@ public void setParameters(Map<String, String> parameters) {

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), modelId, modelType, existingModelId, parameters);
return Objects.hash(super.hashCode(), modelId, existingModelId, parameters);
}

@Override
Expand All @@ -97,7 +91,6 @@ public boolean equals(Object obj) {
}
CreateTrainingStatement target = (CreateTrainingStatement) obj;
return modelId.equals(target.modelId)
&& modelType.equals(target.modelType)
&& Objects.equals(existingModelId, target.existingModelId)
&& Objects.equals(parameters, target.parameters);
}
Expand All @@ -108,9 +101,6 @@ public String toString() {
+ "modelId='"
+ modelId
+ '\''
+ ", modelType='"
+ modelType
+ '\''
+ ", parameters="
+ parameters
+ ", existingModelId='"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ revokeGrantOpt
// ------------------------------------------- AI ---------------------------------------------------------

createModelStatement
: CREATE MODEL modelType=identifier modelId=identifier (WITH HYPERPARAMETERS '(' hparamPair (',' hparamPair)* ')')? (FROM MODEL existingModelId=identifier)? ON DATASET '(' targetData=string ')'
: CREATE MODEL modelId=identifier (WITH HYPERPARAMETERS '(' hparamPair (',' hparamPair)* ')')? FROM MODEL existingModelId=identifier ON DATASET '(' targetData=string ')'
;

hparamPair
Expand Down
3 changes: 1 addition & 2 deletions iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,9 @@ struct IDataSchema {
struct TTrainingReq {
1: required string dbType
2: required string modelId
3: required string modelType
3: required string existingModelId
4: optional list<IDataSchema> targetDataSchema;
5: optional map<string, string> parameters;
6: optional string existingModelId
}

struct TForecastReq {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1098,13 +1098,12 @@ struct TDataSchemaForTree{

struct TCreateTrainingReq {
1: required string modelId
2: required string modelType
3: required bool isTableModel
2: required bool isTableModel
3: required string existingModelId
4: optional TDataSchemaForTable dataSchemaForTable
5: optional TDataSchemaForTree dataSchemaForTree
6: optional map<string, string> parameters
7: optional string existingModelId
8: optional list<list<i64>> timeRanges
7: optional list<list<i64>> timeRanges
}

// ====================================================
Expand Down
Loading