diff --git a/sqlite3.go b/sqlite3.go index 76b0f232..3ef78d59 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -866,26 +866,12 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named } var res driver.Result if s.(*SQLiteStmt).s != nil { - stmtArgs := make([]driver.NamedValue, 0, len(args)) na := s.NumInput() if len(args)-start < na { s.Close() return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } - // consume the number of arguments used in the current - // statement and append all named arguments not - // contained therein - if na > 0 { - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 - } - } + stmtArgs := stmtArgs(args, start, na) res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { s.Close() @@ -921,7 +907,6 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { start := 0 for { - stmtArgs := make([]driver.NamedValue, 0, len(args)) s, err := c.prepare(ctx, query) if err != nil { return nil, err @@ -932,18 +917,7 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name s.Close() return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) } - // consume the number of arguments used in the current - // statement and append all named arguments not contained - // therein - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 - } + stmtArgs := stmtArgs(args, start, na) rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { s.Close() @@ -1957,6 +1931,36 @@ func (s *SQLiteStmt) NumInput() int { var placeHolder = []byte{0} +func stmtArgs(args []driver.NamedValue, start, na int) []driver.NamedValue { + if na == 0 { + return nil + } + + end := start + na + hasNamedOutside := false + for i := range args { + if args[i].Name != "" && (i < start || i >= end) { + hasNamedOutside = true + break + } + } + if start == 0 && !hasNamedOutside { + return args[start:end] + } + + stmtArgs := make([]driver.NamedValue, 0, len(args)) + stmtArgs = append(stmtArgs, args[start:end]...) + for i := range args { + if args[i].Name != "" && (i < start || i >= end) { + stmtArgs = append(stmtArgs, args[i]) + } + } + for i := range stmtArgs { + stmtArgs[i].Ordinal = i + 1 + } + return stmtArgs +} + func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { @@ -1965,26 +1969,79 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { C.sqlite3_clear_bindings(s.s) + hasNamed := false + for i := range args { + if args[i].Name != "" { + hasNamed = true + break + } + } + + if !hasNamed { + for _, arg := range args { + n := C.int(arg.Ordinal) + switch v := arg.Value.(type) { + case nil: + rv = C.sqlite3_bind_null(s.s, n) + case string: + if len(v) == 0 { + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) + } else { + b := []byte(v) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + } + case int64: + rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) + case bool: + if v { + rv = C.sqlite3_bind_int(s.s, n, 1) + } else { + rv = C.sqlite3_bind_int(s.s, n, 0) + } + case float64: + rv = C.sqlite3_bind_double(s.s, n, C.double(v)) + case []byte: + if v == nil { + rv = C.sqlite3_bind_null(s.s, n) + } else { + ln := len(v) + if ln == 0 { + v = placeHolder + } + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) + } + case time.Time: + b := []byte(v.Format(SQLiteTimestampFormats[0])) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + } + if rv != C.SQLITE_OK { + return s.c.lastError() + } + } + return nil + } + bindIndices := make([][3]int, len(args)) - prefixes := []string{":", "@", "$"} + prefixes := [3]string{":", "@", "$"} for i, v := range args { bindIndices[i][0] = v.Ordinal - if v.Name != "" { - for j := range prefixes { - cname := C.CString(prefixes[j] + v.Name) - bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) - C.free(unsafe.Pointer(cname)) - } - args[i].Ordinal = bindIndices[i][0] + if v.Name == "" { + continue + } + for j := range prefixes { + cname := C.CString(prefixes[j] + v.Name) + bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) + C.free(unsafe.Pointer(cname)) } + args[i].Ordinal = bindIndices[i][0] } for i, arg := range args { - for j := range bindIndices[i] { - if bindIndices[i][j] == 0 { + for _, idx := range bindIndices[i] { + if idx == 0 { continue } - n := C.int(bindIndices[i][j]) + n := C.int(idx) switch v := arg.Value.(type) { case nil: rv = C.sqlite3_bind_null(s.s, n)