671 lines
31 KiB
C#
671 lines
31 KiB
C#
using Microsoft.Extensions.Logging;
|
|
using Microsoft.Data.SqlClient;
|
|
using OrpaonVision.Core.Results;
|
|
using OrpaonVision.Core.Training;
|
|
using OrpaonVision.Model.Training;
|
|
using System.Data;
|
|
|
|
namespace OrpaonVision.ConfigApp.Infrastructure.Persistence;
|
|
|
|
/// <summary>
|
|
/// SQL Server 训练任务仓储实现。
|
|
/// </summary>
|
|
public sealed class SqlTrainingTaskStore : ITrainingTaskStore
|
|
{
|
|
// Logger used by the store's operations to record their outcomes.
private readonly ILogger<SqlTrainingTaskStore> _logger;
|
|
// Connection string from which each operation opens its own SqlConnection.
private readonly string _connectionString;
|
|
|
|
/// <summary>
/// Creates a SQL Server backed training-task store.
/// </summary>
/// <param name="logger">Logger that receives success and failure diagnostics.</param>
/// <param name="connectionString">Connection string used to open a new <see cref="SqlConnection"/> per operation.</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="logger"/> or <paramref name="connectionString"/> is null.
/// </exception>
public SqlTrainingTaskStore(ILogger<SqlTrainingTaskStore> logger, string connectionString)
{
    // Fail fast: every operation logs from inside its catch block, so a null logger would
    // otherwise surface later as an exception thrown by the error handler itself.
    _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    _connectionString = connectionString ?? throw new ArgumentNullException(nameof(connectionString));
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Inserts <paramref name="task"/> as a new row of <c>training_tasks</c> and, on success, returns
/// the same instance. Any exception is logged and converted into a failed result with code
/// <c>TRAINING_TASK_CREATE_FAILED</c>.
/// </remarks>
public async Task<Result<TrainingTaskModel>> CreateAsync(TrainingTaskModel task)
{
    // NOTE(review): several placeholders below (@DatasetPath, @ModelPath, @ConfigJson, @AssignedBy,
    // @StartedAtUtc, @CompletedAtUtc, @FinalLoss, @FinalMap) are not among the names bound by
    // AddTaskParameters - confirm this statement against the table schema and the model.
    const string insertSql = @"
        INSERT INTO training_tasks (
            Id, Name, Description, DatasetPath, ModelPath, ConfigJson,
            Status, Priority, AssignedToId, AssignedToName, AssignedBy,
            CurrentEpoch, TotalEpochs, CurrentLoss, CurrentMap, Progress,
            StartedAtUtc, CompletedAtUtc, OutputModelPath, FinalLoss, FinalMap,
            ErrorMessage, CreatedAtUtc, UpdatedAtUtc, CreatedBy, UpdatedBy, Remark
        ) VALUES (
            @Id, @Name, @Description, @DatasetPath, @ModelPath, @ConfigJson,
            @Status, @Priority, @AssignedToId, @AssignedToName, @AssignedBy,
            @CurrentEpoch, @TotalEpochs, @CurrentLoss, @CurrentMap, @Progress,
            @StartedAtUtc, @CompletedAtUtc, @OutputModelPath, @FinalLoss, @FinalMap,
            @ErrorMessage, @CreatedAtUtc, @UpdatedAtUtc, @CreatedBy, @UpdatedBy, @Remark
        )";

    try
    {
        using var db = new SqlConnection(_connectionString);
        await db.OpenAsync();

        using var insertCommand = new SqlCommand(insertSql, db);
        AddTaskParameters(insertCommand, task);
        await insertCommand.ExecuteNonQueryAsync();

        _logger.LogInformation("训练任务创建成功: {TaskId} - {TaskName}", task.Id, task.Name);
        return Result<TrainingTaskModel>.Success(task);
    }
    catch (Exception error)
    {
        _logger.LogError(error, "创建训练任务失败: {TaskName}", task.Name);
        return Result<TrainingTaskModel>.Fail("TRAINING_TASK_CREATE_FAILED", "创建训练任务失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Overwrites the row of <c>training_tasks</c> whose <c>Id</c> equals <c>task.Id</c>. Returns
/// <c>TRAINING_TASK_NOT_FOUND</c> when the statement reports zero affected rows, and
/// <c>TRAINING_TASK_UPDATE_FAILED</c> (after logging) when an exception is thrown.
/// </remarks>
public async Task<Result<TrainingTaskModel>> UpdateAsync(TrainingTaskModel task)
{
    // NOTE(review): several placeholders below (@DatasetPath, @ModelPath, @ConfigJson,
    // @StartedAtUtc, @CompletedAtUtc, @FinalLoss, @FinalMap) are not among the names bound by
    // AddTaskParameters - confirm this statement against the table schema and the model.
    const string updateSql = @"
        UPDATE training_tasks SET
            Name = @Name, Description = @Description, DatasetPath = @DatasetPath, ModelPath = @ModelPath, ConfigJson = @ConfigJson,
            Status = @Status, Priority = @Priority, AssignedToId = @AssignedToId, AssignedToName = @AssignedToName,
            CurrentEpoch = @CurrentEpoch, TotalEpochs = @TotalEpochs, CurrentLoss = @CurrentLoss, CurrentMap = @CurrentMap, Progress = @Progress,
            StartedAtUtc = @StartedAtUtc, CompletedAtUtc = @CompletedAtUtc, OutputModelPath = @OutputModelPath, FinalLoss = @FinalLoss, FinalMap = @FinalMap,
            ErrorMessage = @ErrorMessage, UpdatedAtUtc = @UpdatedAtUtc, UpdatedBy = @UpdatedBy, Remark = @Remark
        WHERE Id = @Id";

    try
    {
        using var db = new SqlConnection(_connectionString);
        await db.OpenAsync();

        using var updateCommand = new SqlCommand(updateSql, db);
        AddTaskParameters(updateCommand, task);

        var affected = await updateCommand.ExecuteNonQueryAsync();
        if (affected != 0)
        {
            _logger.LogInformation("训练任务更新成功: {TaskId} - {TaskName}", task.Id, task.Name);
            return Result<TrainingTaskModel>.Success(task);
        }

        // No row matched the id.
        return Result<TrainingTaskModel>.Fail("TRAINING_TASK_NOT_FOUND", "训练任务不存在");
    }
    catch (Exception error)
    {
        _logger.LogError(error, "更新训练任务失败: {TaskId}", task.Id);
        return Result<TrainingTaskModel>.Fail("TRAINING_TASK_UPDATE_FAILED", "更新训练任务失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Issues a <c>DELETE</c> for the given id. The affected-row count is not inspected, so deleting
/// an id that does not exist still reports success. Exceptions are logged and mapped to
/// <c>TRAINING_TASK_DELETE_FAILED</c>.
/// </remarks>
public async Task<Result> DeleteAsync(Guid id)
{
    const string deleteSql = "DELETE FROM training_tasks WHERE Id = @Id";

    try
    {
        using var db = new SqlConnection(_connectionString);
        await db.OpenAsync();

        using var deleteCommand = new SqlCommand(deleteSql, db);
        deleteCommand.Parameters.Add(new SqlParameter("@Id", id));
        await deleteCommand.ExecuteNonQueryAsync();

        _logger.LogInformation("训练任务删除成功: {TaskId}", id);
        return Result.Success("训练任务删除成功");
    }
    catch (Exception error)
    {
        _logger.LogError(error, "删除训练任务失败: {TaskId}", id);
        return Result.Fail("TRAINING_TASK_DELETE_FAILED", "删除训练任务失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Returns a successful result carrying the mapped task when a row with the given id exists, and
/// a successful result carrying <c>null</c> when it does not. Exceptions are logged and mapped to
/// <c>TRAINING_TASK_GET_BY_ID_FAILED</c>.
/// </remarks>
public async Task<Result<TrainingTaskModel?>> GetByIdAsync(Guid id)
{
    // NOTE(review): MapReaderToTask reads columns (e.g. DatasetId, ModelType, BestLoss) that this
    // SELECT list does not return - confirm the column names against the table schema.
    const string selectSql = @"
        SELECT Id, Name, Description, DatasetPath, ModelPath, ConfigJson,
               Status, Priority, AssignedToId, AssignedToName, AssignedBy,
               CurrentEpoch, TotalEpochs, CurrentLoss, CurrentMap, Progress,
               StartedAtUtc, CompletedAtUtc, OutputModelPath, FinalLoss, FinalMap,
               ErrorMessage, CreatedAtUtc, UpdatedAtUtc, CreatedBy, UpdatedBy, Remark
        FROM training_tasks
        WHERE Id = @Id";

    try
    {
        using var db = new SqlConnection(_connectionString);
        await db.OpenAsync();

        using var selectCommand = new SqlCommand(selectSql, db);
        selectCommand.Parameters.Add(new SqlParameter("@Id", id));

        using var row = await selectCommand.ExecuteReaderAsync();

        // At most one row is consumed; an empty result set maps to a null payload.
        TrainingTaskModel? found = await row.ReadAsync() ? MapReaderToTask(row) : null;
        return Result<TrainingTaskModel?>.Success(found);
    }
    catch (Exception error)
    {
        _logger.LogError(error, "根据ID获取训练任务失败: {TaskId}", id);
        return Result<TrainingTaskModel?>.Fail("TRAINING_TASK_GET_BY_ID_FAILED", "获取训练任务失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Runs two statements on one connection: a <c>COUNT</c> over the filtered rows, then the page
/// itself ordered by <c>CreatedAtUtc</c> descending. The row offset is
/// <paramref name="pageIndex"/> multiplied by <paramref name="pageSize"/>. Optional filters are
/// combined with <c>AND</c>; <paramref name="keyword"/> is matched with <c>LIKE</c> against both
/// <c>Name</c> and <c>Description</c>. Invalid paging values and exceptions both produce a failed
/// result with code <c>TRAINING_TASK_GET_PAGED_LIST_FAILED</c>.
/// </remarks>
public async Task<Result<(IReadOnlyList<TrainingTaskModel> tasks, int totalCount)>> GetPagedListAsync(
    int pageIndex,
    int pageSize,
    TrainingTaskStatus? status = null,
    string? keyword = null,
    Guid? assignedToId = null)
{
    try
    {
        // OFFSET rejects negative values and FETCH rejects non-positive ones, so refuse them
        // here instead of paying for a round trip that can only fail.
        if (pageIndex < 0 || pageSize <= 0)
        {
            _logger.LogWarning("获取训练任务分页列表参数无效: PageIndex={PageIndex}, PageSize={PageSize}", pageIndex, pageSize);
            return Result<(IReadOnlyList<TrainingTaskModel>, int)>.Fail("TRAINING_TASK_GET_PAGED_LIST_FAILED", "获取训练任务列表失败");
        }

        // Widened to long so a large page index cannot overflow into a negative offset.
        var offset = (long)pageIndex * pageSize;

        var whereConditions = new List<string>();
        var filterValues = new List<(string Name, object Value)>();

        if (status.HasValue)
        {
            whereConditions.Add("Status = @Status");
            filterValues.Add(("@Status", (int)status.Value));
        }

        if (!string.IsNullOrWhiteSpace(keyword))
        {
            whereConditions.Add("(Name LIKE @Keyword OR Description LIKE @Keyword)");
            filterValues.Add(("@Keyword", $"%{keyword}%"));
        }

        if (assignedToId.HasValue)
        {
            whereConditions.Add("AssignedToId = @AssignedToId");
            filterValues.Add(("@AssignedToId", assignedToId.Value));
        }

        var whereClause = whereConditions.Count > 0
            ? $"WHERE {string.Join(" AND ", whereConditions)}"
            : "";

        // A SqlParameter instance may belong to only one command's collection, so each command
        // gets its own freshly created parameters rather than sharing one list.
        void BindFilters(SqlCommand target)
        {
            foreach (var (name, value) in filterValues)
            {
                target.Parameters.AddWithValue(name, value);
            }
        }

        var countSql = $"SELECT COUNT(1) FROM training_tasks {whereClause}";

        // NOTE(review): MapReaderToTask reads columns (e.g. DatasetId, ModelType, BestLoss) that
        // this SELECT list does not return - confirm the column names against the table schema.
        var dataSql = $@"
            SELECT Id, Name, Description, DatasetPath, ModelPath, ConfigJson,
                   Status, Priority, AssignedToId, AssignedToName, AssignedBy,
                   CurrentEpoch, TotalEpochs, CurrentLoss, CurrentMap, Progress,
                   StartedAtUtc, CompletedAtUtc, OutputModelPath, FinalLoss, FinalMap,
                   ErrorMessage, CreatedAtUtc, UpdatedAtUtc, CreatedBy, UpdatedBy, Remark
            FROM training_tasks
            {whereClause}
            ORDER BY CreatedAtUtc DESC
            OFFSET @Offset ROWS FETCH NEXT @PageSize ROWS ONLY";

        using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync();

        // Total number of rows matching the filters (ignores paging).
        using var countCommand = new SqlCommand(countSql, connection);
        BindFilters(countCommand);
        var totalCount = Convert.ToInt32(await countCommand.ExecuteScalarAsync());

        // The requested page.
        var tasks = new List<TrainingTaskModel>();
        using var dataCommand = new SqlCommand(dataSql, connection);
        BindFilters(dataCommand);
        dataCommand.Parameters.AddWithValue("@Offset", offset);
        dataCommand.Parameters.AddWithValue("@PageSize", pageSize);

        using var reader = await dataCommand.ExecuteReaderAsync();
        while (await reader.ReadAsync())
        {
            tasks.Add(MapReaderToTask(reader));
        }

        return Result<(IReadOnlyList<TrainingTaskModel>, int)>.Success((tasks.AsReadOnly(), totalCount));
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "获取训练任务分页列表失败");
        return Result<(IReadOnlyList<TrainingTaskModel>, int)>.Fail("TRAINING_TASK_GET_PAGED_LIST_FAILED", "获取训练任务列表失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Returns every task whose <c>AssignedToId</c> equals <paramref name="userId"/>, optionally
/// restricted to one <paramref name="status"/>, ordered by <c>CreatedAtUtc</c> descending.
/// Exceptions are logged and mapped to <c>TRAINING_TASK_GET_USER_TASKS_FAILED</c>.
/// </remarks>
public async Task<Result<IReadOnlyList<TrainingTaskModel>>> GetUserTasksAsync(Guid userId, TrainingTaskStatus? status = null)
{
    try
    {
        // The status predicate is appended only when a status filter was supplied.
        var predicate = status.HasValue
            ? "AssignedToId = @UserId AND Status = @Status"
            : "AssignedToId = @UserId";

        // NOTE(review): MapReaderToTask reads columns (e.g. DatasetId, ModelType, BestLoss) that
        // this SELECT list does not return - confirm the column names against the table schema.
        var query = $@"
            SELECT Id, Name, Description, DatasetPath, ModelPath, ConfigJson,
                   Status, Priority, AssignedToId, AssignedToName, AssignedBy,
                   CurrentEpoch, TotalEpochs, CurrentLoss, CurrentMap, Progress,
                   StartedAtUtc, CompletedAtUtc, OutputModelPath, FinalLoss, FinalMap,
                   ErrorMessage, CreatedAtUtc, UpdatedAtUtc, CreatedBy, UpdatedBy, Remark
            FROM training_tasks
            WHERE {predicate}
            ORDER BY CreatedAtUtc DESC";

        using var db = new SqlConnection(_connectionString);
        await db.OpenAsync();

        using var selectCommand = new SqlCommand(query, db);
        selectCommand.Parameters.AddWithValue("@UserId", userId);
        if (status is { } wantedStatus)
        {
            selectCommand.Parameters.AddWithValue("@Status", (int)wantedStatus);
        }

        var found = new List<TrainingTaskModel>();
        using var rows = await selectCommand.ExecuteReaderAsync();
        while (await rows.ReadAsync())
        {
            found.Add(MapReaderToTask(rows));
        }

        return Result<IReadOnlyList<TrainingTaskModel>>.Success(found.AsReadOnly());
    }
    catch (Exception error)
    {
        _logger.LogError(error, "获取用户训练任务失败: {UserId}", userId);
        return Result<IReadOnlyList<TrainingTaskModel>>.Fail("TRAINING_TASK_GET_USER_TASKS_FAILED", "获取用户训练任务失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Computes all counters with a single aggregate query over <c>training_tasks</c>, then derives
/// <c>SuccessRate</c> as completed / (completed + failed + cancelled) * 100, or 0 when no task
/// has finished. Exceptions are logged and mapped to <c>TRAINING_TASK_GET_STATISTICS_FAILED</c>.
/// </remarks>
public async Task<Result<TrainingTaskStatistics>> GetStatisticsAsync()
{
    try
    {
        // NOTE(review): the numeric status literals (0-6) are assumed to mirror the
        // TrainingTaskStatus enum values - confirm against the enum definition.
        //
        // SUM over zero rows yields NULL, so every counter is wrapped in ISNULL(..., 0) to keep
        // GetInt32 from failing on an empty table. AVG over an int expression yields int, which
        // GetDouble refuses to read, so the duration is cast to FLOAT before averaging.
        const string sql = @"
            SELECT
                COUNT(*) as TotalTasks,
                ISNULL(SUM(CASE WHEN Status = 0 THEN 1 ELSE 0 END), 0) as DraftTasks,
                ISNULL(SUM(CASE WHEN Status = 1 THEN 1 ELSE 0 END), 0) as PendingTasks,
                ISNULL(SUM(CASE WHEN Status = 2 THEN 1 ELSE 0 END), 0) as RunningTasks,
                ISNULL(SUM(CASE WHEN Status = 3 THEN 1 ELSE 0 END), 0) as PausedTasks,
                ISNULL(SUM(CASE WHEN Status = 4 THEN 1 ELSE 0 END), 0) as CompletedTasks,
                ISNULL(SUM(CASE WHEN Status = 5 THEN 1 ELSE 0 END), 0) as FailedTasks,
                ISNULL(SUM(CASE WHEN Status = 6 THEN 1 ELSE 0 END), 0) as CancelledTasks,
                ISNULL(SUM(CASE WHEN Status = 4 AND CAST(CompletedAtUtc AS DATE) = CAST(GETUTCDATE() AS DATE) THEN 1 ELSE 0 END), 0) as TodayCompletedTasks,
                ISNULL(SUM(CASE WHEN Status = 4 AND CompletedAtUtc >= DATEADD(day, -7, GETUTCDATE()) THEN 1 ELSE 0 END), 0) as ThisWeekCompletedTasks,
                ISNULL(SUM(CASE WHEN Status = 4 AND CompletedAtUtc >= DATEADD(day, -30, GETUTCDATE()) THEN 1 ELSE 0 END), 0) as ThisMonthCompletedTasks,
                AVG(CASE WHEN Status = 4 AND StartedAtUtc IS NOT NULL AND CompletedAtUtc IS NOT NULL
                         THEN CAST(DATEDIFF(minute, StartedAtUtc, CompletedAtUtc) AS FLOAT) ELSE NULL END) as AverageTrainingMinutes
            FROM training_tasks";

        using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync();

        using var command = new SqlCommand(sql, connection);

        using var reader = await command.ExecuteReaderAsync();
        if (await reader.ReadAsync())
        {
            var statistics = new TrainingTaskStatistics
            {
                TotalTasks = reader.GetInt32("TotalTasks"),
                DraftTasks = reader.GetInt32("DraftTasks"),
                PendingTasks = reader.GetInt32("PendingTasks"),
                RunningTasks = reader.GetInt32("RunningTasks"),
                PausedTasks = reader.GetInt32("PausedTasks"),
                CompletedTasks = reader.GetInt32("CompletedTasks"),
                FailedTasks = reader.GetInt32("FailedTasks"),
                CancelledTasks = reader.GetInt32("CancelledTasks"),
                TodayCompletedTasks = reader.GetInt32("TodayCompletedTasks"),
                ThisWeekCompletedTasks = reader.GetInt32("ThisWeekCompletedTasks"),
                ThisMonthCompletedTasks = reader.GetInt32("ThisMonthCompletedTasks"),
                // The average stays NULL when no completed task has both timestamps.
                AverageTrainingMinutes = reader.IsDBNull("AverageTrainingMinutes") ? 0 : reader.GetDouble("AverageTrainingMinutes")
            };

            // Success rate over tasks that reached a terminal state.
            var totalFinishedTasks = statistics.CompletedTasks + statistics.FailedTasks + statistics.CancelledTasks;
            statistics.SuccessRate = totalFinishedTasks > 0
                ? (double)statistics.CompletedTasks / totalFinishedTasks * 100
                : 0;

            return Result<TrainingTaskStatistics>.Success(statistics);
        }

        // Fallback kept for the case where the reader yields no row.
        return Result<TrainingTaskStatistics>.Success(new TrainingTaskStatistics());
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "获取训练任务统计信息失败");
        return Result<TrainingTaskStatistics>.Fail("TRAINING_TASK_GET_STATISTICS_FAILED", "获取统计信息失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Sets <c>Status</c>, <c>UpdatedAtUtc</c> (current UTC time) and <c>UpdatedBy</c> for the given
/// id, then reloads the row once and returns it. Returns <c>TRAINING_TASK_NOT_FOUND</c> when no
/// row was affected, <c>TRAINING_TASK_GET_AFTER_UPDATE_FAILED</c> when the reload does not yield
/// a task, and <c>TRAINING_TASK_UPDATE_STATUS_FAILED</c> (after logging) on exceptions.
/// </remarks>
public async Task<Result<TrainingTaskModel>> UpdateStatusAsync(Guid id, TrainingTaskStatus status, string operatedBy)
{
    try
    {
        const string sql = @"
            UPDATE training_tasks SET
                Status = @Status, UpdatedAtUtc = @UpdatedAtUtc, UpdatedBy = @UpdatedBy
            WHERE Id = @Id";

        using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync();

        using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@Id", id);
        command.Parameters.AddWithValue("@Status", (int)status);
        command.Parameters.AddWithValue("@UpdatedAtUtc", DateTime.UtcNow);
        command.Parameters.AddWithValue("@UpdatedBy", operatedBy);

        var rowsAffected = await command.ExecuteNonQueryAsync();
        if (rowsAffected == 0)
        {
            return Result<TrainingTaskModel>.Fail("TRAINING_TASK_NOT_FOUND", "训练任务不存在");
        }

        // Reload exactly once and reuse that result; a second query would cost another round
        // trip and could observe a different (possibly deleted) row.
        var reloaded = await GetByIdAsync(id);
        return reloaded is { Succeeded: true, Data: { } updatedTask }
            ? Result<TrainingTaskModel>.Success(updatedTask)
            : Result<TrainingTaskModel>.Fail("TRAINING_TASK_GET_AFTER_UPDATE_FAILED", "更新后获取任务失败");
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "更新训练任务状态失败: {TaskId}", id);
        return Result<TrainingTaskModel>.Fail("TRAINING_TASK_UPDATE_STATUS_FAILED", "更新任务状态失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Writes the current epoch, loss, mAP and progress values plus <c>UpdatedAtUtc</c> (current UTC
/// time) for the given id, then reloads the row once and returns it. Returns
/// <c>TRAINING_TASK_NOT_FOUND</c> when no row was affected,
/// <c>TRAINING_TASK_GET_AFTER_UPDATE_FAILED</c> when the reload does not yield a task, and
/// <c>TRAINING_TASK_UPDATE_PROGRESS_FAILED</c> (after logging) on exceptions.
/// </remarks>
public async Task<Result<TrainingTaskModel>> UpdateProgressAsync(Guid id, int currentEpoch, double currentLoss, double currentMap, double progress)
{
    try
    {
        const string sql = @"
            UPDATE training_tasks SET
                CurrentEpoch = @CurrentEpoch, CurrentLoss = @CurrentLoss, CurrentMap = @CurrentMap, Progress = @Progress,
                UpdatedAtUtc = @UpdatedAtUtc
            WHERE Id = @Id";

        using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync();

        using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@Id", id);
        command.Parameters.AddWithValue("@CurrentEpoch", currentEpoch);
        command.Parameters.AddWithValue("@CurrentLoss", currentLoss);
        command.Parameters.AddWithValue("@CurrentMap", currentMap);
        command.Parameters.AddWithValue("@Progress", progress);
        command.Parameters.AddWithValue("@UpdatedAtUtc", DateTime.UtcNow);

        var rowsAffected = await command.ExecuteNonQueryAsync();
        if (rowsAffected == 0)
        {
            return Result<TrainingTaskModel>.Fail("TRAINING_TASK_NOT_FOUND", "训练任务不存在");
        }

        // Reload exactly once and reuse that result; a second query would cost another round
        // trip and could observe a different (possibly deleted) row.
        var reloaded = await GetByIdAsync(id);
        return reloaded is { Succeeded: true, Data: { } updatedTask }
            ? Result<TrainingTaskModel>.Success(updatedTask)
            : Result<TrainingTaskModel>.Fail("TRAINING_TASK_GET_AFTER_UPDATE_FAILED", "更新后获取任务失败");
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "更新训练任务进度失败: {TaskId}", id);
        return Result<TrainingTaskModel>.Fail("TRAINING_TASK_UPDATE_PROGRESS_FAILED", "更新任务进度失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Stores the assignee id and name together with <paramref name="assignedBy"/>, which is written
/// to both <c>AssignedBy</c> and <c>UpdatedBy</c>, stamps <c>UpdatedAtUtc</c> with the current UTC
/// time, then reloads the row once and returns it. Returns <c>TRAINING_TASK_NOT_FOUND</c> when no
/// row was affected, <c>TRAINING_TASK_GET_AFTER_UPDATE_FAILED</c> when the reload does not yield
/// a task, and <c>TRAINING_TASK_ASSIGN_FAILED</c> (after logging) on exceptions.
/// </remarks>
public async Task<Result<TrainingTaskModel>> AssignTaskAsync(Guid id, Guid assignedToId, string assignedToName, string assignedBy)
{
    try
    {
        const string sql = @"
            UPDATE training_tasks SET
                AssignedToId = @AssignedToId, AssignedToName = @AssignedToName, AssignedBy = @AssignedBy,
                UpdatedAtUtc = @UpdatedAtUtc, UpdatedBy = @UpdatedBy
            WHERE Id = @Id";

        using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync();

        using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@Id", id);
        command.Parameters.AddWithValue("@AssignedToId", assignedToId);
        command.Parameters.AddWithValue("@AssignedToName", assignedToName);
        command.Parameters.AddWithValue("@AssignedBy", assignedBy);
        command.Parameters.AddWithValue("@UpdatedAtUtc", DateTime.UtcNow);
        command.Parameters.AddWithValue("@UpdatedBy", assignedBy);

        var rowsAffected = await command.ExecuteNonQueryAsync();
        if (rowsAffected == 0)
        {
            return Result<TrainingTaskModel>.Fail("TRAINING_TASK_NOT_FOUND", "训练任务不存在");
        }

        // Reload exactly once and reuse that result; a second query would cost another round
        // trip and could observe a different (possibly deleted) row.
        var reloaded = await GetByIdAsync(id);
        return reloaded is { Succeeded: true, Data: { } updatedTask }
            ? Result<TrainingTaskModel>.Success(updatedTask)
            : Result<TrainingTaskModel>.Fail("TRAINING_TASK_GET_AFTER_UPDATE_FAILED", "更新后获取任务失败");
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "分配训练任务失败: {TaskId}", id);
        return Result<TrainingTaskModel>.Fail("TRAINING_TASK_ASSIGN_FAILED", "分配任务失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Marks the task as <see cref="TrainingTaskStatus.Completed"/>, sets <c>Progress</c> to 100 and
/// stores the output model path and final metrics, then reloads the row once and returns it.
/// <c>CompletedAtUtc</c> and <c>UpdatedAtUtc</c> receive the same UTC timestamp. Returns
/// <c>TRAINING_TASK_NOT_FOUND</c> when no row was affected,
/// <c>TRAINING_TASK_GET_AFTER_UPDATE_FAILED</c> when the reload does not yield a task, and
/// <c>TRAINING_TASK_COMPLETE_FAILED</c> (after logging) on exceptions.
/// </remarks>
public async Task<Result<TrainingTaskModel>> CompleteTaskAsync(Guid id, string outputModelPath, double finalLoss, double finalMap, string completedBy)
{
    try
    {
        const string sql = @"
            UPDATE training_tasks SET
                Status = @Status, CompletedAtUtc = @CompletedAtUtc, OutputModelPath = @OutputModelPath,
                FinalLoss = @FinalLoss, FinalMap = @FinalMap, Progress = 100,
                UpdatedAtUtc = @UpdatedAtUtc, UpdatedBy = @UpdatedBy
            WHERE Id = @Id";

        // Read the clock once so both timestamp columns describe the same instant.
        var nowUtc = DateTime.UtcNow;

        using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync();

        using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@Id", id);
        command.Parameters.AddWithValue("@Status", (int)TrainingTaskStatus.Completed);
        command.Parameters.AddWithValue("@CompletedAtUtc", nowUtc);
        command.Parameters.AddWithValue("@OutputModelPath", outputModelPath);
        command.Parameters.AddWithValue("@FinalLoss", finalLoss);
        command.Parameters.AddWithValue("@FinalMap", finalMap);
        command.Parameters.AddWithValue("@UpdatedAtUtc", nowUtc);
        command.Parameters.AddWithValue("@UpdatedBy", completedBy);

        var rowsAffected = await command.ExecuteNonQueryAsync();
        if (rowsAffected == 0)
        {
            return Result<TrainingTaskModel>.Fail("TRAINING_TASK_NOT_FOUND", "训练任务不存在");
        }

        // Reload exactly once and reuse that result; a second query would cost another round
        // trip and could observe a different (possibly deleted) row.
        var reloaded = await GetByIdAsync(id);
        return reloaded is { Succeeded: true, Data: { } updatedTask }
            ? Result<TrainingTaskModel>.Success(updatedTask)
            : Result<TrainingTaskModel>.Fail("TRAINING_TASK_GET_AFTER_UPDATE_FAILED", "更新后获取任务失败");
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "完成训练任务失败: {TaskId}", id);
        return Result<TrainingTaskModel>.Fail("TRAINING_TASK_COMPLETE_FAILED", "完成任务失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Marks the task as <see cref="TrainingTaskStatus.Failed"/> and stores
/// <paramref name="errorMessage"/>, then reloads the row once and returns it.
/// <c>CompletedAtUtc</c> and <c>UpdatedAtUtc</c> receive the same UTC timestamp. Returns
/// <c>TRAINING_TASK_NOT_FOUND</c> when no row was affected,
/// <c>TRAINING_TASK_GET_AFTER_UPDATE_FAILED</c> when the reload does not yield a task, and
/// <c>TRAINING_TASK_FAIL_FAILED</c> (after logging) on exceptions.
/// </remarks>
public async Task<Result<TrainingTaskModel>> FailTaskAsync(Guid id, string errorMessage, string failedBy)
{
    try
    {
        const string sql = @"
            UPDATE training_tasks SET
                Status = @Status, ErrorMessage = @ErrorMessage, CompletedAtUtc = @CompletedAtUtc,
                UpdatedAtUtc = @UpdatedAtUtc, UpdatedBy = @UpdatedBy
            WHERE Id = @Id";

        // Read the clock once so both timestamp columns describe the same instant.
        var nowUtc = DateTime.UtcNow;

        using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync();

        using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@Id", id);
        command.Parameters.AddWithValue("@Status", (int)TrainingTaskStatus.Failed);
        command.Parameters.AddWithValue("@ErrorMessage", errorMessage);
        command.Parameters.AddWithValue("@CompletedAtUtc", nowUtc);
        command.Parameters.AddWithValue("@UpdatedAtUtc", nowUtc);
        command.Parameters.AddWithValue("@UpdatedBy", failedBy);

        var rowsAffected = await command.ExecuteNonQueryAsync();
        if (rowsAffected == 0)
        {
            return Result<TrainingTaskModel>.Fail("TRAINING_TASK_NOT_FOUND", "训练任务不存在");
        }

        // Reload exactly once and reuse that result; a second query would cost another round
        // trip and could observe a different (possibly deleted) row.
        var reloaded = await GetByIdAsync(id);
        return reloaded is { Succeeded: true, Data: { } updatedTask }
            ? Result<TrainingTaskModel>.Success(updatedTask)
            : Result<TrainingTaskModel>.Fail("TRAINING_TASK_GET_AFTER_UPDATE_FAILED", "更新后获取任务失败");
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "失败训练任务失败: {TaskId}", id);
        return Result<TrainingTaskModel>.Fail("TRAINING_TASK_FAIL_FAILED", "任务失败操作失败");
    }
}
|
|
|
|
/// <inheritdoc />
/// <remarks>
/// Counts the rows whose <c>Name</c> equals <paramref name="name"/>, skipping the row identified
/// by <paramref name="excludeId"/> when one is supplied, and reports whether the count is
/// positive. Exceptions are logged and mapped to <c>TRAINING_TASK_CHECK_NAME_FAILED</c>.
/// </remarks>
public async Task<Result<bool>> NameExistsAsync(string name, Guid? excludeId = null)
{
    const string baseSql = "SELECT COUNT(1) FROM training_tasks WHERE Name = @Name";

    try
    {
        // The exclusion predicate is appended only when an id to skip was supplied.
        var query = excludeId.HasValue
            ? baseSql + " AND Id != @ExcludeId"
            : baseSql;

        using var db = new SqlConnection(_connectionString);
        await db.OpenAsync();

        using var countCommand = new SqlCommand(query, db);
        countCommand.Parameters.AddWithValue("@Name", name);
        if (excludeId is { } idToSkip)
        {
            countCommand.Parameters.AddWithValue("@ExcludeId", idToSkip);
        }

        var matches = (int)await countCommand.ExecuteScalarAsync();
        return Result<bool>.Success(matches > 0);
    }
    catch (Exception error)
    {
        _logger.LogError(error, "检查训练任务名称是否存在失败: {Name}", name);
        return Result<bool>.Fail("TRAINING_TASK_CHECK_NAME_FAILED", "检查任务名称失败");
    }
}
|
|
|
|
/// <summary>
/// Binds the values of <paramref name="task"/> to <paramref name="command"/> as named SQL parameters.
/// </summary>
/// <param name="command">Command that receives the parameters.</param>
/// <param name="task">Task whose property values are bound; nullable values are sent as <see cref="DBNull.Value"/>.</param>
/// <remarks>
/// Each parameter name is bound exactly once: SQL Server rejects a batch in which the same
/// variable name is declared twice.
/// </remarks>
private static void AddTaskParameters(SqlCommand command, TrainingTaskModel task)
{
    // NOTE(review): the INSERT/UPDATE statements in this class reference @DatasetPath, @ModelPath,
    // @ConfigJson, @AssignedBy, @StartedAtUtc, @CompletedAtUtc, @FinalLoss and @FinalMap, none of
    // which is bound here, while @DatasetId, @ModelType, @TrainingParametersJson,
    // @ValidationParametersJson, @ActualStartAtUtc, @ActualEndAtUtc, @BestLoss and @BestMap are
    // bound but not referenced there. Reconcile the names against the table schema.
    var parameters = command.Parameters;

    parameters.AddWithValue("@Id", task.Id);
    parameters.AddWithValue("@Name", task.Name);
    parameters.AddWithValue("@Description", (object?)task.Description ?? DBNull.Value);
    parameters.AddWithValue("@DatasetId", task.DatasetId);
    parameters.AddWithValue("@ModelType", task.ModelType);
    parameters.AddWithValue("@TrainingParametersJson", task.TrainingParametersJson);
    parameters.AddWithValue("@ValidationParametersJson", task.ValidationParametersJson);
    parameters.AddWithValue("@Status", (int)task.Status);
    parameters.AddWithValue("@Priority", (int)task.Priority);
    parameters.AddWithValue("@AssignedToId", (object?)task.AssignedToId ?? DBNull.Value);
    parameters.AddWithValue("@AssignedToName", (object?)task.AssignedToName ?? DBNull.Value);
    parameters.AddWithValue("@CurrentEpoch", task.CurrentEpoch);
    parameters.AddWithValue("@TotalEpochs", task.TotalEpochs);
    parameters.AddWithValue("@CurrentLoss", task.CurrentLoss);
    parameters.AddWithValue("@CurrentMap", task.CurrentMap);
    parameters.AddWithValue("@Progress", task.Progress);
    parameters.AddWithValue("@ActualStartAtUtc", (object?)task.ActualStartAtUtc ?? DBNull.Value);
    parameters.AddWithValue("@ActualEndAtUtc", (object?)task.ActualEndAtUtc ?? DBNull.Value);
    parameters.AddWithValue("@OutputModelPath", (object?)task.OutputModelPath ?? DBNull.Value);
    parameters.AddWithValue("@BestLoss", task.BestLoss);
    parameters.AddWithValue("@BestMap", task.BestMap);
    parameters.AddWithValue("@ErrorMessage", (object?)task.ErrorMessage ?? DBNull.Value);
    parameters.AddWithValue("@CreatedAtUtc", task.CreatedAtUtc);
    parameters.AddWithValue("@UpdatedAtUtc", task.UpdatedAtUtc);
    parameters.AddWithValue("@CreatedBy", task.CreatedBy);
    parameters.AddWithValue("@UpdatedBy", task.UpdatedBy);
    parameters.AddWithValue("@Remark", (object?)task.Remark ?? DBNull.Value);
}
|
|
|
|
/// <summary>
/// Builds a <see cref="TrainingTaskModel"/> from the row the reader is currently positioned on.
/// </summary>
/// <param name="reader">Reader positioned on a row; columns are looked up by name.</param>
/// <returns>A new model populated from the current row.</returns>
private static TrainingTaskModel MapReaderToTask(SqlDataReader reader)
{
    // NOTE(review): the SELECT lists in this class return DatasetPath, ModelPath, ConfigJson,
    // StartedAtUtc, CompletedAtUtc, FinalLoss and FinalMap, whereas this mapper looks up
    // DatasetId, ModelType, TrainingParametersJson, ValidationParametersJson, ActualStartAtUtc,
    // ActualEndAtUtc, BestLoss and BestMap - confirm the column names against the table schema.

    // Reads a nullable text column, mapping database NULL to a null reference.
    string? TextOrNull(string column) => reader.IsDBNull(column) ? null : reader.GetString(column);

    return new TrainingTaskModel
    {
        Id = reader.GetGuid("Id"),
        Name = reader.GetString("Name"),
        Description = TextOrNull("Description"),
        DatasetId = reader.GetGuid("DatasetId"),
        ModelType = reader.GetString("ModelType"),
        TrainingParametersJson = reader.GetString("TrainingParametersJson"),
        ValidationParametersJson = reader.GetString("ValidationParametersJson"),
        Status = (TrainingTaskStatus)reader.GetInt32("Status"),
        Priority = (TrainingPriority)reader.GetInt32("Priority"),
        AssignedToId = reader.IsDBNull("AssignedToId") ? null : reader.GetGuid("AssignedToId"),
        AssignedToName = TextOrNull("AssignedToName"),
        CreatedBy = reader.GetString("CreatedBy"),
        CurrentEpoch = reader.GetInt32("CurrentEpoch"),
        TotalEpochs = reader.GetInt32("TotalEpochs"),
        CurrentLoss = reader.GetDouble("CurrentLoss"),
        CurrentMap = reader.GetDouble("CurrentMap"),
        Progress = reader.GetDouble("Progress"),
        ActualStartAtUtc = reader.IsDBNull("ActualStartAtUtc") ? null : reader.GetDateTime("ActualStartAtUtc"),
        ActualEndAtUtc = reader.IsDBNull("ActualEndAtUtc") ? null : reader.GetDateTime("ActualEndAtUtc"),
        OutputModelPath = TextOrNull("OutputModelPath"),
        BestLoss = reader.GetDouble("BestLoss"),
        BestMap = reader.GetDouble("BestMap"),
        ErrorMessage = TextOrNull("ErrorMessage"),
        CreatedAtUtc = reader.GetDateTime("CreatedAtUtc"),
        UpdatedAtUtc = reader.GetDateTime("UpdatedAtUtc"),
        UpdatedBy = reader.GetString("UpdatedBy"),
        Remark = TextOrNull("Remark")
    };
}
|
|
}
|