Skip to content

Commit

Permalink
feat: refactor CreateItems(...) to batch inserts
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt committed Jul 13, 2024
1 parent c8a28e0 commit 8be13f4
Show file tree
Hide file tree
Showing 8 changed files with 450 additions and 85 deletions.
186 changes: 144 additions & 42 deletions coverage/coverage.html

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions coverage/coverage.log
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
1693014148,90.9
1720903770,90.8
2 changes: 1 addition & 1 deletion coverage/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
147 changes: 117 additions & 30 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ const (

// DeleteOp is a delete operation
DeleteOp OpType = 3

// DefaultBatchSize is the default batch size for bulk operations like
// CreateItems. This value is used if the caller does not specify a size
// using the WithBatchSize(...) option.
DefaultBatchSize = 100
)

// VetForWriter provides an interface that Create and Update can use to vet the
Expand Down Expand Up @@ -184,58 +189,140 @@ func (rw *RW) Create(ctx context.Context, i interface{}, opt ...Option) error {
}

// CreateItems will create multiple items of the same type. Supported options:
// WithDebug, WithBeforeWrite, WithAfterWrite, WithReturnRowsAffected,
// OnConflict, WithVersion, WithTable, and WithWhere. WithLookup is not a supported option.
func (rw *RW) CreateItems(ctx context.Context, createItems []interface{}, opt ...Option) error {
// WithBatchSize, WithDebug, WithBeforeWrite, WithAfterWrite,
// WithReturnRowsAffected, OnConflict, WithVersion, WithTable, and WithWhere.
// WithLookup is not a supported option.
func (rw *RW) CreateItems(ctx context.Context, createItems interface{}, opt ...Option) error {
const op = "dbw.CreateItems"
if rw.underlying == nil {
switch {
case rw.underlying == nil:
return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
case isNil(createItems):
return fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter)
}
if len(createItems) == 0 {
return fmt.Errorf("%s: missing interfaces: %w", op, ErrInvalidParameter)
valCreateItems := reflect.ValueOf(createItems)
switch {
case valCreateItems.Kind() != reflect.Slice:
return fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter)
case valCreateItems.Len() == 0:
return fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter)
}
if err := raiseErrorOnHooks(createItems); err != nil {
return fmt.Errorf("%s: %w", op, err)
}
opts := GetOpts(opt...)
if opts.WithLookup {
switch {
case opts.WithLookup:
return fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter)
}
// verify that createItems are all the same type.
var foundType reflect.Type
for i, v := range createItems {
for i := 0; i < valCreateItems.Len(); i++ {
// verify that createItems are all the same type and do some bits on each item
if i == 0 {
foundType = reflect.TypeOf(v)
foundType = reflect.TypeOf(valCreateItems.Index(i).Interface())
}
currentType := reflect.TypeOf(valCreateItems.Index(i).Interface())
if currentType == nil {
return fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter)
}
currentType := reflect.TypeOf(v)
if foundType != currentType {
return fmt.Errorf("%s: create items contains disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter)
}
}
if opts.WithBeforeWrite != nil {
if err := opts.WithBeforeWrite(createItems); err != nil {
return fmt.Errorf("%s: error before write: %w", op, err)

// these fields should be nil, since they are not writeable and we want the
// db to manage them
setFieldsToNil(valCreateItems.Index(i).Interface(), NonCreatableFields())

// vet each item
if !opts.WithSkipVetForWrite {
if vetter, ok := valCreateItems.Index(i).Interface().(VetForWriter); ok {
if err := vetter.VetForWrite(ctx, rw, CreateOp); err != nil {
return fmt.Errorf("%s: %w", op, err)
}
}
}

if opts.WithBeforeWrite != nil {
if err := opts.WithBeforeWrite(valCreateItems.Index(i).Interface()); err != nil {
return fmt.Errorf("%s: error before write: %w", op, err)
}
}
}
var rowsAffected int64
for _, item := range createItems {
if err := rw.Create(ctx, item,
WithOnConflict(opts.WithOnConflict),
WithReturnRowsAffected(&rowsAffected),
WithDebug(opts.WithDebug),
WithVersion(opts.WithVersion),
WithWhere(opts.WithWhereClause, opts.WithWhereClauseArgs...),
WithTable(opts.WithTable),
); err != nil {
return fmt.Errorf("%s: %w", op, err)

db := rw.underlying.wrapped.WithContext(ctx)
if opts.WithOnConflict != nil {
c := clause.OnConflict{}
switch opts.WithOnConflict.Target.(type) {
case Constraint:
c.OnConstraint = string(opts.WithOnConflict.Target.(Constraint))
case Columns:
columns := make([]clause.Column, 0, len(opts.WithOnConflict.Target.(Columns)))
for _, name := range opts.WithOnConflict.Target.(Columns) {
columns = append(columns, clause.Column{Name: name})
}
c.Columns = columns
default:
return fmt.Errorf("%s: invalid conflict target %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Target), ErrInvalidParameter)
}

switch opts.WithOnConflict.Action.(type) {
case DoNothing:
c.DoNothing = true
case UpdateAll:
c.UpdateAll = true
case []ColumnValue:
updates := opts.WithOnConflict.Action.([]ColumnValue)
set := make(clause.Set, 0, len(updates))
for _, s := range updates {
// make sure it's not one of the std immutable columns
if contains([]string{"createtime", "publicid"}, strings.ToLower(s.Column)) {
return fmt.Errorf("%s: cannot do update on conflict for column %s: %w", op, s.Column, ErrInvalidParameter)
}
switch sv := s.Value.(type) {
case Column:
set = append(set, sv.toAssignment(s.Column))
case ExprValue:
set = append(set, sv.toAssignment(s.Column))
default:
set = append(set, rawAssignment(s.Column, s.Value))
}
}
c.DoUpdates = set
default:
return fmt.Errorf("%s: invalid conflict action %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Action), ErrInvalidParameter)
}
if opts.WithVersion != nil || opts.WithWhereClause != "" {
// this is a bit of a hack, but we need to pass in one of the items
// to get the where clause since we need to get the gorm Model and
// Parse the gorm statement to build the where clause
where, args, err := rw.whereClausesFromOpts(ctx, valCreateItems.Index(0).Interface(), opts)
if err != nil {
return fmt.Errorf("%s: %w", op, err)
}
whereConditions := db.Statement.BuildCondition(where, args...)
c.Where = clause.Where{Exprs: whereConditions}
}
db = db.Clauses(c)
}
if opts.WithDebug {
db = db.Debug()
}
if opts.WithTable != "" {
db = db.Table(opts.WithTable)
}

tx := db.CreateInBatches(createItems, opts.WithBatchSize)
if tx.Error != nil {
return fmt.Errorf("%s: create failed: %w", op, tx.Error)
}
if opts.WithRowsAffected != nil {
*opts.WithRowsAffected = rowsAffected
*opts.WithRowsAffected = tx.RowsAffected
}
if opts.WithAfterWrite != nil {
if err := opts.WithAfterWrite(createItems, int(rowsAffected)); err != nil {
return fmt.Errorf("%s: error after write: %w", op, err)
if tx.RowsAffected > 0 && opts.WithAfterWrite != nil {
for i := 0; i < valCreateItems.Len(); i++ {
if err := opts.WithAfterWrite(valCreateItems.Index(i).Interface(), int(tx.RowsAffected)); err != nil {
return fmt.Errorf("%s: error after write: %w", op, err)
}
}
}
return nil
Expand Down
Loading

0 comments on commit 8be13f4

Please sign in to comment.