Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ func addOutputFlag(flags *pflag.FlagSet) {
"directory where downloaded files are put")
}

func addPrivateFlag(flags *pflag.FlagSet) {
flags.BoolP(
"private", "P", false,
"include private ones (requires private scanning privileges)")
}

func addFilterFlag(flags *pflag.FlagSet) {
flags.StringP(
"filter", "f", "",
Expand Down
2 changes: 1 addition & 1 deletion cmd/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func NewCollectionCmd() *cobra.Command {
cmd.AddCommand(NewCollectionDeleteCmd())
cmd.AddCommand(NewCollectionRemoveItemsCmd())

addRelationshipCmds(cmd, "collections", "collection", "[collection]")
addRelationshipCmds(cmd, "collections", "collection", "[collection]", false)
addThreadsFlag(cmd.Flags())
addIncludeExcludeFlags(cmd.Flags())
addIDOnlyFlag(cmd.Flags())
Expand Down
2 changes: 1 addition & 1 deletion cmd/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func NewDomainCmd() *cobra.Command {
},
}

addRelationshipCmds(cmd, "domains", "domain", "[domain]")
addRelationshipCmds(cmd, "domains", "domain", "[domain]", false)

addThreadsFlag(cmd.Flags())
addIncludeExcludeFlags(cmd.Flags())
Expand Down
21 changes: 15 additions & 6 deletions cmd/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
package cmd

import (
"github.com/VirusTotal/vt-cli/utils"
"regexp"

"github.com/VirusTotal/vt-cli/utils"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

var fileCmdHelp = `Get information about one or more files.
Expand Down Expand Up @@ -50,18 +51,26 @@ func NewFileCmd() *cobra.Command {
if err != nil {
return err
}
return p.GetAndPrintObjects(
"files/%s",
utils.StringReaderFromCmdArgs(args),
re)
if viper.GetBool("private") {
return p.GetAndPrintObjectsWithFallback(
[]string{"files/%s", "private/files/%s"},
utils.StringReaderFromCmdArgs(args),
re)
} else {
return p.GetAndPrintObjects(
"files/%s",
utils.StringReaderFromCmdArgs(args),
re)
}
},
}

addRelationshipCmds(cmd, "files", "file", "[hash]")
addRelationshipCmds(cmd, "files", "file", "[hash]", true)

addThreadsFlag(cmd.Flags())
addIncludeExcludeFlags(cmd.Flags())
addIDOnlyFlag(cmd.Flags())
addPrivateFlag(cmd.Flags())

return cmd
}
5 changes: 3 additions & 2 deletions cmd/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
package cmd

import (
"github.com/VirusTotal/vt-cli/utils"
"regexp"

"github.com/VirusTotal/vt-cli/utils"

"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -55,7 +56,7 @@ func NewIPCmd() *cobra.Command {
},
}

addRelationshipCmds(cmd, "ip_addresses", "ip_address", "[ip]")
addRelationshipCmds(cmd, "ip_addresses", "ip_address", "[ip]", false)

addThreadsFlag(cmd.Flags())
addIncludeExcludeFlags(cmd.Flags())
Expand Down
2 changes: 1 addition & 1 deletion cmd/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ func NewMonitorCmd() *cobra.Command {
cmd.AddCommand(NewMonitorItemsSetDetailsCmd())
cmd.AddCommand(NewMonitorItemsDeleteDetailsCmd())

addRelationshipCmds(cmd, "monitor/items", "monitor_item", "[monitor_id]")
addRelationshipCmds(cmd, "monitor/items", "monitor_item", "[monitor_id]", false)

return cmd
}
2 changes: 1 addition & 1 deletion cmd/monitorpartner.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func NewMonitorPartnerCmd() *cobra.Command {
cmd.AddCommand(NewMonitorPartnerHashesListCmd())
cmd.AddCommand(NewMonitorPartnerHashDownloadCmd())

addRelationshipCmds(cmd, "monitor_partner/hashes", "monitor_hash", "[sha256]")
addRelationshipCmds(cmd, "monitor_partner/hashes", "monitor_hash", "[sha256]", false)

return cmd
}
28 changes: 22 additions & 6 deletions cmd/relationship.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (
"encoding/base64"
"encoding/gob"
"fmt"
"github.com/VirusTotal/vt-cli/utils"
"os"
"path"
"sync"

"github.com/VirusTotal/vt-cli/utils"

vt "github.com/VirusTotal/vt-go"
homedir "github.com/mitchellh/go-homedir"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -69,7 +70,7 @@ func getRelatedObjects(collection, objectID, relationship string, limit int) ([]
}

// NewRelationshipCmd returns a new instance of the 'relationship' command.
func NewRelationshipCmd(collection, relationship, use, description string) *cobra.Command {
func NewRelationshipCmd(collection, relationship, use, description string, private_flag bool) *cobra.Command {
cmd := &cobra.Command{
Args: cobra.ExactArgs(1),
Use: fmt.Sprintf("%s %s", relationship, use),
Expand All @@ -85,6 +86,9 @@ func NewRelationshipCmd(collection, relationship, use, description string) *cobr
if err != nil {
return err
}
if viper.GetBool("private") {
collection = "private/" + collection
}
url := vt.URL("%s/%s/%s", collection, objectID, relationship)
return p.PrintCollection(url)
},
Expand All @@ -95,18 +99,26 @@ func NewRelationshipCmd(collection, relationship, use, description string) *cobr
addLimitFlag(cmd.Flags())
addCursorFlag(cmd.Flags())

if private_flag {
addPrivateFlag(cmd.Flags())
}

return cmd
}

// NewRelationshipsCmd returns a new instance of the 'relationships' command.
func NewRelationshipsCmd(collection, objectType, use string) *cobra.Command {
func NewRelationshipsCmd(collection, objectType, use string, private_flag bool) *cobra.Command {
cmd := &cobra.Command{
Use: fmt.Sprintf("relationships %s", use),
Short: "Get all relationships.",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
var wg sync.WaitGroup
var sm sync.Map
if viper.GetBool("private") {
objectType = "private_" + objectType
collection = "private/" + collection
}
for _, r := range objectRelationshipsMap[objectType] {
wg.Add(1)
go func(relationshipName string) {
Expand Down Expand Up @@ -148,13 +160,17 @@ func NewRelationshipsCmd(collection, objectType, use string) *cobra.Command {
addIncludeExcludeFlags(cmd.Flags())
addLimitFlag(cmd.Flags())

if private_flag {
addPrivateFlag(cmd.Flags())
}

return cmd
}

func addRelationshipCmds(cmd *cobra.Command, collection, objectType, use string) {
func addRelationshipCmds(cmd *cobra.Command, collection, objectType, use string, private_flag bool) {
relationships := objectRelationshipsMap[objectType]
for _, r := range relationships {
cmd.AddCommand(NewRelationshipCmd(collection, r.Name, use, r.Description))
cmd.AddCommand(NewRelationshipCmd(collection, r.Name, use, r.Description, private_flag))
}
cmd.AddCommand(NewRelationshipsCmd(collection, objectType, use))
cmd.AddCommand(NewRelationshipsCmd(collection, objectType, use, private_flag))
}
2 changes: 1 addition & 1 deletion cmd/threat_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func NewThreatProfileCmd() *cobra.Command {
},
}

addRelationshipCmds(cmd, "threat_profiles", "threat_profile", "[id]")
addRelationshipCmds(cmd, "threat_profiles", "threat_profile", "[id]", false)
addThreadsFlag(cmd.Flags())
addIncludeExcludeFlags(cmd.Flags())
addIDOnlyFlag(cmd.Flags())
Expand Down
23 changes: 18 additions & 5 deletions cmd/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ package cmd

import (
"encoding/base64"
"regexp"

"github.com/VirusTotal/vt-cli/utils"
"github.com/spf13/cobra"
"regexp"
"github.com/spf13/viper"
)

var urlCmdHelp = `Get information about one or more URLs.
Expand All @@ -35,7 +37,6 @@ var urlCmdExample = ` vt url https://www.virustotal.com
vt url f1177df4692356280844e1d5af67cc4a9eccecf77aa61c229d483b7082c70a8e
cat list_of_urls | vt url -`


// Regular expressions used for validating a URL identifier.
var urlID = regexp.MustCompile(`[0-9a-fA-F]{64}`)

Expand All @@ -55,7 +56,7 @@ func NewURLCmd() *cobra.Command {
}
r := utils.NewMappedStringReader(
utils.StringReaderFromCmdArgs(args),
func (url string) string {
func(url string) string {
if urlID.MatchString(url) {
// The user provided a URL identifier as returned by
// VirusTotal's API, which consists in the URL's SHA-256.
Expand All @@ -66,15 +67,27 @@ func NewURLCmd() *cobra.Command {
// encoded as base64 before being used.
return base64.RawURLEncoding.EncodeToString([]byte(url))
})
return p.GetAndPrintObjects("urls/%s", r, nil)

if viper.GetBool("private") {
return p.GetAndPrintObjectsWithFallback(
[]string{"urls/%s", "private/urls/%s"},
r,
nil)
} else {
return p.GetAndPrintObjects(
"urls/%s",
r,
nil)
}
},
}

addRelationshipCmds(cmd, "urls", "url", "[url]")
addRelationshipCmds(cmd, "urls", "url", "[url]", true)

addThreadsFlag(cmd.Flags())
addIncludeExcludeFlags(cmd.Flags())
addIDOnlyFlag(cmd.Flags())
addPrivateFlag(cmd.Flags())

return cmd
}
26 changes: 21 additions & 5 deletions utils/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ func NewAPIClient(agent string) (*APIClient, error) {
// must contain a %s placeholder that will be replaced with items from the args
// slice. The objects are put into the outCh as they are retrieved.
func (c *APIClient) RetrieveObjects(endpoint string, args []string, outCh chan *vt.Object, errCh chan error) error {
return c.RetrieveObjectsWithFallback([]string{endpoint}, args, outCh, errCh)
}

// RetrieveObjectsWithFallback retrieves objects from the specified endpoints. It
// tries the endpoints in the order they are provided until one of them returns
// the object. The endpoint strings must contain a %s placeholder that will be
// replaced with items from the args slice. The objects are put into the outCh
// as they are retrieved.
func (c *APIClient) RetrieveObjectsWithFallback(endpoints []string, args []string, outCh chan *vt.Object, errCh chan error) error {

// Make sure outCh and errCh are closed
defer close(outCh)
Expand Down Expand Up @@ -75,17 +84,24 @@ func (c *APIClient) RetrieveObjects(endpoint string, args []string, outCh chan *
getWg.Add(1)
go func(order int, arg string) {
throttler <- nil
obj, err := c.GetObject(vt.URL(endpoint, arg))
if err == nil {
objCh <- PQueueNode{Priority: order, Data: obj}
} else {
var obj *vt.Object
var err error
for _, endpoint := range endpoints {
obj, err = c.GetObject(vt.URL(endpoint, arg))
if err == nil {
objCh <- PQueueNode{Priority: order, Data: obj}
break
}
if apiErr, ok := err.(vt.Error); ok && apiErr.Code == "NotFoundError" {
objCh <- PQueueNode{Priority: order, Data: err}
// Try the next endpoint
} else {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
if err != nil {
objCh <- PQueueNode{Priority: order, Data: err}
}
getWg.Done()
<-throttler
}(order, arg)
Expand Down
12 changes: 11 additions & 1 deletion utils/printer.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ func (p *Printer) PrintObject(obj *vt.Object) error {
// read from stdin one per line. If argRe is non-nil, only args that match the
// regular expression are used and the rest are discarded.
func (p *Printer) GetAndPrintObjects(endpoint string, r StringReader, argRe *regexp.Regexp) error {
return p.GetAndPrintObjectsWithFallback([]string{endpoint}, r, argRe)
}

// GetAndPrintObjectsWithFallback retrieves objects from the specified endpoints and
// prints them. The function tries the endpoints in the order they are provided
// until one of them returns the object. The endpoint must contain a %s placeholder
// that will be replaced with items from the args slice. If args contains a single
// "-" string, the args are read from stdin one per line. If argRe is non-nil, only
// args that match the regular expression are used and the rest are discarded.
func (p *Printer) GetAndPrintObjectsWithFallback(endpoints []string, r StringReader, argRe *regexp.Regexp) error {
if argRe != nil {
r = NewFilteredStringReader(r, argRe)
}
Expand All @@ -171,7 +181,7 @@ func (p *Printer) GetAndPrintObjects(endpoint string, r StringReader, argRe *reg
objectsCh := make(chan *vt.Object)
errorsCh := make(chan error, len(filteredArgs))

go p.client.RetrieveObjects(endpoint, filteredArgs, objectsCh, errorsCh)
go p.client.RetrieveObjectsWithFallback(endpoints, filteredArgs, objectsCh, errorsCh)

if viper.GetBool("identifiers-only") {
var objectIds []string
Expand Down
Loading