Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add number field in trials table. #88

Merged
merged 1 commit into from
Mar 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions rdb/attrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
// See https://github.com/c-bata/goptuna/issues/34
// for the reason why we need following code.

// Caution "_number" in trial_system_attributes must not be encoded.

func encodeAttrValue(xr string) string {
return fmt.Sprintf("\"%s\"",
base64.StdEncoding.EncodeToString([]byte(xr)))
Expand Down
23 changes: 4 additions & 19 deletions rdb/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package rdb

import (
"errors"
"fmt"
"strconv"
"time"

"github.com/c-bata/goptuna"
Expand All @@ -21,13 +19,9 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {

systemAttrs := make(map[string]string, len(trial.SystemAttributes))
for i := range trial.SystemAttributes {
if trial.SystemAttributes[i].Key == keyNumber {
systemAttrs[trial.SystemAttributes[i].Key] = trial.SystemAttributes[i].ValueJSON
} else {
systemAttrs[trial.SystemAttributes[i].Key], err = decodeAttrValue(trial.SystemAttributes[i].ValueJSON)
if err != nil {
return goptuna.FrozenTrial{}, err
}
systemAttrs[trial.SystemAttributes[i].Key], err = decodeAttrValue(trial.SystemAttributes[i].ValueJSON)
if err != nil {
return goptuna.FrozenTrial{}, err
}
}

Expand All @@ -52,15 +46,6 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {
}
}

numberStr, ok := systemAttrs[keyNumber]
if !ok {
return goptuna.FrozenTrial{}, errors.New("number is not exist in system attrs")
}
number, err := strconv.Atoi(numberStr)
if err != nil {
return goptuna.FrozenTrial{}, fmt.Errorf("invalid trial number: %s", err)
}

state, err := toStateExternalRepresentation(trial.State)
if err != nil {
return goptuna.FrozenTrial{}, err
Expand All @@ -82,7 +67,7 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {
return goptuna.FrozenTrial{
ID: trial.ID,
StudyID: trial.TrialReferStudy,
Number: number,
Number: trial.Number,
State: state,
Value: trial.Value,
IntermediateValues: intermediateValue,
Expand Down
1 change: 1 addition & 0 deletions rdb/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func (m studySystemAttributeModel) TableName() string {

type trialModel struct {
ID int `gorm:"column:trial_id;PRIMARY_KEY"`
Number int `gorm:"column:number"`
TrialReferStudy int `gorm:"column:study_id"`
State string `gorm:"column:state;NOT NULL"`
Value float64 `gorm:"column:value"`
Expand Down
30 changes: 9 additions & 21 deletions rdb/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rdb

import (
"fmt"
"strconv"
"time"

"github.com/c-bata/goptuna"
Expand All @@ -12,8 +11,6 @@ import (

var _ goptuna.Storage = &Storage{}

const keyNumber = "_number"

// NewStorage returns new RDB storage.
func NewStorage(db *gorm.DB) *Storage {
return &Storage{
Expand Down Expand Up @@ -234,12 +231,9 @@ func (s *Storage) CreateNewTrial(studyID int) (int, error) {
return -1, err
}

// Set '_number' in trial_system_attributes.
err = tx.Create(&trialSystemAttributeModel{
SystemAttributeReferTrial: trial.ID,
Key: keyNumber,
ValueJSON: strconv.Itoa(number),
}).Error
err = tx.Model(&trialModel{}).
Where("trial_id = ?", trial.ID).
Update("number", number).Error
if err != nil {
tx.Rollback()
return -1, err
Expand Down Expand Up @@ -313,9 +307,6 @@ func (s *Storage) CloneTrial(studyID int, baseTrial goptuna.FrozenTrial) (int, e

// system attrs
for key := range baseTrial.SystemAttrs {
if key == "_number" {
continue
}
err := tx.Create(&trialSystemAttributeModel{
SystemAttributeReferTrial: trial.ID,
Key: key,
Expand Down Expand Up @@ -364,11 +355,10 @@ func (s *Storage) CloneTrial(studyID int, baseTrial goptuna.FrozenTrial) (int, e
tx.Rollback()
return -1, err
}
err = tx.Create(&trialSystemAttributeModel{
SystemAttributeReferTrial: trial.ID,
Key: keyNumber,
ValueJSON: strconv.Itoa(number),
}).Error

err = tx.Model(&trialModel{}).
Where("trial_id = ?", trial.ID).
Update("number", number).Error
if err != nil {
tx.Rollback()
return -1, err
Expand Down Expand Up @@ -579,13 +569,11 @@ func (s *Storage) SetTrialSystemAttr(trialID int, key string, value string) erro

// GetTrialNumberFromID returns the trial's number.
func (s *Storage) GetTrialNumberFromID(trialID int) (int, error) {
var attr trialSystemAttributeModel
err := s.db.First(&attr, "trial_id = ? AND key = ?", trialID, keyNumber).Error
trial, err := s.GetTrial(trialID)
if err != nil {
return -1, err
}
number, err := strconv.Atoi(attr.ValueJSON)
return number, err
return trial.Number, err
}

// GetTrialParam returns the internal parameter of the trial
Expand Down