Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 96 additions & 39 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,26 +866,12 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named
}
var res driver.Result
if s.(*SQLiteStmt).s != nil {
stmtArgs := make([]driver.NamedValue, 0, len(args))
na := s.NumInput()
if len(args)-start < na {
s.Close()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
}
// consume the number of arguments used in the current
// statement and append all named arguments not
// contained therein
if na > 0 {
stmtArgs = append(stmtArgs, args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
}
for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
}
stmtArgs := stmtArgs(args, start, na)
res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.Close()
Expand Down Expand Up @@ -921,7 +907,6 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
start := 0
for {
stmtArgs := make([]driver.NamedValue, 0, len(args))
s, err := c.prepare(ctx, query)
if err != nil {
return nil, err
Expand All @@ -932,18 +917,7 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
s.Close()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
}
// consume the number of arguments used in the current
// statement and append all named arguments not contained
// therein
stmtArgs = append(stmtArgs, args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
}
for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
stmtArgs := stmtArgs(args, start, na)
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.Close()
Expand Down Expand Up @@ -1957,6 +1931,36 @@ func (s *SQLiteStmt) NumInput() int {

var placeHolder = []byte{0}

func stmtArgs(args []driver.NamedValue, start, na int) []driver.NamedValue {
if na == 0 {
return nil
}

end := start + na
hasNamedOutside := false
for i := range args {
if args[i].Name != "" && (i < start || i >= end) {
hasNamedOutside = true
break
}
}
if start == 0 && !hasNamedOutside {
return args[start:end]
}

stmtArgs := make([]driver.NamedValue, 0, len(args))
stmtArgs = append(stmtArgs, args[start:end]...)
for i := range args {
if args[i].Name != "" && (i < start || i >= end) {
stmtArgs = append(stmtArgs, args[i])
}
}
for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
return stmtArgs
}

func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
Expand All @@ -1965,26 +1969,79 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {

C.sqlite3_clear_bindings(s.s)

hasNamed := false
for i := range args {
if args[i].Name != "" {
hasNamed = true
break
}
}

if !hasNamed {
for _, arg := range args {
n := C.int(arg.Ordinal)
switch v := arg.Value.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
if len(v) == 0 {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
} else {
b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case bool:
if v {
rv = C.sqlite3_bind_int(s.s, n, 1)
} else {
rv = C.sqlite3_bind_int(s.s, n, 0)
}
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
if v == nil {
rv = C.sqlite3_bind_null(s.s, n)
} else {
ln := len(v)
if ln == 0 {
v = placeHolder
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
}
case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return s.c.lastError()
}
}
return nil
}

bindIndices := make([][3]int, len(args))
prefixes := []string{":", "@", "$"}
prefixes := [3]string{":", "@", "$"}
for i, v := range args {
bindIndices[i][0] = v.Ordinal
if v.Name != "" {
for j := range prefixes {
cname := C.CString(prefixes[j] + v.Name)
bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
C.free(unsafe.Pointer(cname))
}
args[i].Ordinal = bindIndices[i][0]
if v.Name == "" {
continue
}
for j := range prefixes {
cname := C.CString(prefixes[j] + v.Name)
bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
C.free(unsafe.Pointer(cname))
}
args[i].Ordinal = bindIndices[i][0]
}

for i, arg := range args {
for j := range bindIndices[i] {
if bindIndices[i][j] == 0 {
for _, idx := range bindIndices[i] {
if idx == 0 {
continue
}
n := C.int(bindIndices[i][j])
n := C.int(idx)
switch v := arg.Value.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
Expand Down
Loading