From 8550f1a0bffa753075c5bb6da74d6e02a5121845 Mon Sep 17 00:00:00 2001 From: Adem Rosic <83852285+The-Inceptions@users.noreply.github.com> Date: Tue, 16 Jan 2024 11:41:15 -0500 Subject: [PATCH 1/2] Modified existing comments Modified existing comments in assetdb.go to align with the SQL query actually in use. --- assetdb.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/assetdb.go b/assetdb.go index 91a4335..bc421bd 100644 --- a/assetdb.go +++ b/assetdb.go @@ -110,13 +110,13 @@ func (as *AssetDB) RawQuery(sqlstr string, results interface{}) error { } // AssetQuery executes a query against the asset table of the db. -// For SQL databases, the query will start with "SELECT * FROM assets " and then add the necessary constraints. +// For SQL databases, the query will start with "SELECT assets.id, assets.create_at, assets.last_seen, assets.type, assets.content FROM " and then add the necessary constraints. func (as *AssetDB) AssetQuery(constraints string) ([]*types.Asset, error) { return as.repository.AssetQuery(constraints) } // RelationQuery executes a query against the relation table of the db. -// For SQL databases, the query will start with "SELECT * FROM relations " and then add the necessary constraints. +// For SQL databases, the query will start with "SELECT relations.id, relations.create_at, relations.last_seen, relations.type, relations.from_asset_id, relations.to_asset_id FROM " and then add the necessary constraints. func (as *AssetDB) RelationQuery(constraints string) ([]*types.Relation, error) { return as.repository.RelationQuery(constraints) } From eb0a927c8c51e6205219c0621bbff535afc272f3 Mon Sep 17 00:00:00 2001 From: The-Inceptions <83852285+The-Inceptions@users.noreply.github.com> Date: Wed, 17 Jan 2024 11:44:33 -0500 Subject: [PATCH 2/2] Added bools in relation query, allowing selective fill --- assetdb.go | 5 +++-- assetdb_test.go | 34 +++++++++++++++++++++++++++--- repository/repository.go | 2 +- repository/sql.go | 45 ++++++++++++++++++++++++++-------------- 4 files changed, 65 insertions(+), 21 deletions(-) diff --git a/assetdb.go b/assetdb.go index bc421bd..573f563 100644 --- a/assetdb.go +++ b/assetdb.go @@ -116,7 +116,8 @@ func (as *AssetDB) AssetQuery(constraints string) ([]*types.Asset, error) { } // RelationQuery executes a query against the relation table of the db. +// The fillFrom and fillTo parameters determine whether the source and destination assets of the relation should be filled. // For SQL databases, the query will start with "SELECT relations.id, relations.create_at, relations.last_seen, relations.type, relations.from_asset_id, relations.to_asset_id FROM " and then add the necessary constraints. -func (as *AssetDB) RelationQuery(constraints string) ([]*types.Relation, error) { - return as.repository.RelationQuery(constraints) +func (as *AssetDB) RelationQuery(constraints string, fillFrom, fillTo bool) ([]*types.Relation, error) { + return as.repository.RelationQuery(constraints, fillFrom, fillTo) } diff --git a/assetdb_test.go b/assetdb_test.go index d409640..ac2a13e 100644 --- a/assetdb_test.go +++ b/assetdb_test.go @@ -472,7 +472,7 @@ func TestRelationQuery(t *testing.T) { createdAssets := createAssets(db) createdRelations := createRelations(createdAssets, db) - queriedRelations, err := db.RelationQuery("") + queriedRelations, err := db.RelationQuery("", true, true) if err != nil { t.Errorf("%v", err) return @@ -486,6 +486,34 @@ func TestRelationQuery(t *testing.T) { assert.Contains(t, createdAssets, relation.FromAsset) assert.Contains(t, createdAssets, relation.ToAsset) } + + // Verify the relations and assets are not populated in the relation + queriedRelations, err = db.RelationQuery("", false, false) + if err != nil { + t.Errorf("%v", err) + return + } + for k, relation := range queriedRelations { + assert.Equal(t, createdRelations[k].ID, relation.ID) + assert.Equal(t, createdRelations[k].Type, relation.Type) + assert.Equal(t, createdRelations[k].LastSeen, relation.LastSeen) + assert.Equal(t, createdRelations[k].FromAsset.ID, relation.FromAsset.ID) + assert.Equal(t, createdRelations[k].ToAsset.ID, relation.ToAsset.ID) + } + + // Verify the relations and only the from asset is populated in the relation + queriedRelations, err = db.RelationQuery("", true, false) + if err != nil { + t.Errorf("%v", err) + return + } + for k, relation := range queriedRelations { + assert.Equal(t, createdRelations[k].ID, relation.ID) + assert.Equal(t, createdRelations[k].Type, relation.Type) + assert.Equal(t, createdRelations[k].LastSeen, relation.LastSeen) + assert.Contains(t, createdAssets, relation.FromAsset) + assert.Equal(t, createdRelations[k].ToAsset.ID, relation.ToAsset.ID) + } } func createRelations(assets []*types.Asset, db *AssetDB) []*types.Relation { @@ -698,7 +726,7 @@ func (m *mockAssetDB) AssetQuery(query string) ([]*types.Asset, error) { return args.Get(0).([]*types.Asset), args.Error(1) } -func (m *mockAssetDB) RelationQuery(constraints string) ([]*types.Relation, error) { - args := m.Called(constraints) +func (m *mockAssetDB) RelationQuery(constraints string, fillFrom, fillTo bool) ([]*types.Relation, error) { + args := m.Called(constraints, fillFrom, fillTo) return args.Get(0).([]*types.Relation), args.Error(1) } diff --git a/repository/repository.go b/repository/repository.go index ed3ea1c..6455b85 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -25,5 +25,5 @@ type Repository interface { OutgoingRelations(asset *types.Asset, since time.Time, relationTypes ...string) ([]*types.Relation, error) RawQuery(sqlstr string, results interface{}) error AssetQuery(constraints string) ([]*types.Asset, error) - RelationQuery(constraints string) ([]*types.Relation, error) + RelationQuery(constraints string, fillFrom, fillTo bool) ([]*types.Relation, error) } diff --git a/repository/sql.go b/repository/sql.go index 246dc2d..24935f7 100644 --- a/repository/sql.go +++ b/repository/sql.go @@ -564,7 +564,8 @@ func (sql *sqlRepository) gormAssetToAsset(ga *Asset) (*types.Asset, error) { // RelationQuery creates a query and returns the slice of Relations found. The query will start with: // "SELECT relations.id, relations.create_at, relations.last_seen, relations.type, relations.from_asset_id, relations.to_asset_id FROM " // and then add the provided constraints. The query much include the relations table and remain named relations for parsing. -func (sql *sqlRepository) RelationQuery(constraints string) ([]*types.Relation, error) { +// The fillFrom and fillTo parameters determine whether the source and destination assets of the relation should be filled. +func (sql *sqlRepository) RelationQuery(constraints string, fillFrom, fillTo bool) ([]*types.Relation, error) { var rs []*Relation if constraints == "" { @@ -578,29 +579,43 @@ func (sql *sqlRepository) RelationQuery(constraints string) ([]*types.Relation, var relations []*types.Relation for _, r := range rs { - if relation, err := sql.gormRelationToRelation(r); err == nil { + if relation, err := sql.gormRelationToRelation(r, fillFrom, fillTo); err == nil { relations = append(relations, relation) + } else { + return nil, err } } return relations, nil } -func (sql *sqlRepository) gormRelationToRelation(gr *Relation) (*types.Relation, error) { - fromasset, err := sql.FindAssetById(strconv.FormatUint(gr.FromAssetID, 10), time.Time{}) - if err != nil { - return nil, err - } - toasset, err := sql.FindAssetById(strconv.FormatUint(gr.ToAssetID, 10), time.Time{}) - if err != nil { - return nil, err - } +func (sql *sqlRepository) gormRelationToRelation(gr *Relation, fillFrom, fillTo bool) (*types.Relation, error) { - return &types.Relation{ + relation := &types.Relation{ ID: strconv.FormatUint(gr.ID, 10), CreatedAt: gr.CreatedAt, LastSeen: gr.LastSeen, Type: gr.Type, - FromAsset: fromasset, - ToAsset: toasset, - }, nil + } + + if fillFrom { + if fromasset, err := sql.FindAssetById(strconv.FormatUint(gr.FromAssetID, 10), time.Time{}); err == nil { + relation.FromAsset = fromasset + } else { + return nil, err + } + } else { + relation.FromAsset = &types.Asset{ID: strconv.FormatUint(gr.FromAssetID, 10)} + } + + if fillTo { + if toasset, err := sql.FindAssetById(strconv.FormatUint(gr.ToAssetID, 10), time.Time{}); err == nil { + relation.ToAsset = toasset + } else { + return nil, err + } + } else { + relation.ToAsset = &types.Asset{ID: strconv.FormatUint(gr.ToAssetID, 10)} + } + + return relation, nil }