Skip to content
This repository was archived by the owner on Nov 10, 2020. It is now read-only.

Commit a99417e

Browse files
committed
Replace Trainer.playground with Trainer.swift
Update README.md to reflect improved training time
1 parent e76dc8f commit a99417e

File tree

6 files changed

+66
-60
lines changed

6 files changed

+66
-60
lines changed
Binary file not shown.

Diff for: ProgrammingLanguageClassifier.mlmodel

1.94 KB
Binary file not shown.

Diff for: README.md

+14-4
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,20 @@ $ cd Programming-Language-Classifier
4747
$ git submodule update --init
4848
```
4949

50-
- Open `Trainer.playground` and fill in the placeholder values
51-
for `destinationPath` and `corpusPath`.
52-
- Run the playground and wait for the model to be trained
53-
_(on a 2017 MacBook Pro, this took about an hour)_.
50+
- Open `Trainer.swift` in an editor and fill in the placeholder values
51+
for `destinationPath` and `corpusPath`:
52+
53+
```terminal
54+
$ open ./Trainer.swift
55+
```
56+
57+
- Run `Trainer.swift` and wait for the model to be trained
58+
_(on a 2017 MacBook Pro, this took a few minutes)_:
59+
60+
```terminal
61+
$ swift ./Trainer.swift
62+
```
63+
5464
- Compile the generated `.mlmodel` bundle using the following command:
5565

5666
```terminal

Diff for: Trainer.playground/Sources/ProgrammingLanguage.swift

-35
This file was deleted.

Diff for: Trainer.playground/contents.xcplayground

-4
This file was deleted.

Diff for: Trainer.playground/Contents.swift renamed to Trainer.swift

+52-17
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#!/usr/bin/swift
2+
13
import Foundation
24

35
guard #available(OSX 10.14, *) else {
@@ -7,54 +9,87 @@ guard #available(OSX 10.14, *) else {
79
import CreateML
810
import NaturalLanguage
911

10-
//let destinationPath = <#Path to Destination.mlmodel#>
12+
/// The programming languages that source files in the corpus can be labeled as.
///
/// Each case's raw value is the human-readable language name.
enum ProgrammingLanguage: String {
    case c = "C"
    case cPlusPlus = "C++"
    case go = "Go"
    case java = "Java"
    case javaScript = "JavaScript"
    case objectiveC = "Objective-C"
    case php = "PHP"
    case ruby = "Ruby"
    case rust = "Rust"
    case swift = "Swift"

    /// Infers a language from a directory name and a file extension.
    ///
    /// Every recognized extension except `h` maps to a language regardless of
    /// the directory. Header files (`h`) are shared by several languages, so
    /// they are resolved by the directory they belong to.
    ///
    /// - Parameters:
    ///   - directory: Name of the directory the file was found under.
    ///   - fileExtension: The file's extension, without a leading dot.
    /// - Returns: `nil` when the extension is missing or the combination of
    ///   directory and extension is not recognized.
    init?(directory: String, fileExtension: String?) {
        guard let ext = fileExtension else {
            return nil
        }

        switch ext {
        case "c":
            self = .c
        case "cc", "cpp":
            self = .cPlusPlus
        case "go":
            self = .go
        case "java":
            self = .java
        case "js":
            self = .javaScript
        case "m":
            self = .objectiveC
        case "php":
            self = .php
        case "rb":
            self = .ruby
        case "rs":
            self = .rust
        case "swift":
            self = .swift
        case "h":
            // A header alone is ambiguous; the containing directory decides.
            switch directory {
            case "c":
                self = .c
            case "cc":
                self = .cPlusPlus
            case "objective-c":
                self = .objectiveC
            default:
                return nil
            }
        default:
            return nil
        }
    }
}
1141

12-
//let corpusPath = "<#Path to Corpus Directory#>"
42+
let destinationPath = "/Users/mattt/Desktop/Classifier.mlmodel"
43+
44+
let corpusPath = "/Users/mattt/Downloads/code-corpora"
1345
let corpusURL = URL(fileURLWithPath: corpusPath)
1446

1547
let fileManager = FileManager.default
1648

17-
try fileManager.contentsOfDirectory(at: corpusURL, includingPropertiesForKeys: [.isDirectoryKey], options: .skipsHiddenFiles)
18-
1949
do {
20-
var corpus = try MLDataTable(dictionary: ["text": [""], "label": [""]])
21-
50+
var corpus: [(text: String, label: String)] = []
51+
2252
for directory in try fileManager.contentsOfDirectory(at: corpusURL, includingPropertiesForKeys: [.isDirectoryKey], options: [.skipsHiddenFiles]) {
2353
guard directory.hasDirectoryPath,
2454
let enumerator = fileManager.enumerator(at: directory, includingPropertiesForKeys: [.isDirectoryKey])
2555
else {
2656
continue
2757
}
28-
58+
2959
for case let resource as URL in enumerator {
3060
guard !resource.hasDirectoryPath,
3161
let language = ProgrammingLanguage(directory: directory.lastPathComponent, fileExtension: resource.pathExtension),
3262
let text = try? String(contentsOf: resource)
3363
else {
3464
continue
3565
}
36-
37-
let dataTable = try MLDataTable(dictionary: ["text": text, "label": language.description])
38-
corpus.append(contentsOf: dataTable)
66+
corpus.append((text: text, label: language.rawValue))
3967
}
4068
}
41-
42-
let (trainingData, testingData) = corpus.randomSplit(by: 0.9, seed: 0)
43-
69+
70+
let (texts, labels) = corpus.reduce(into: ([String](), [String]())) {
71+
$0.0.append($1.text)
72+
$0.1.append($1.label)
73+
}
74+
75+
let dataTable = try MLDataTable(dictionary: ["text": texts, "label": labels])
76+
77+
let (trainingData, testingData) = dataTable.randomSplit(by: 0.9, seed: 0)
78+
4479
// As of Xcode 10.0 beta (10L176w),
4580
// attempting to use the CRF algorithm results in EXC_BAD_ACCESS.
4681
/*
4782
let parameters = MLTextClassifier.ModelParameters(validationData: validationData, algorithm: .crf(revision: 1), language: .english)
4883
let classifier = try MLTextClassifier(trainingData: trainingData, textColumn: "text", labelColumn: "label", parameters: parameters)
4984
*/
50-
85+
5186
let classifier = try MLTextClassifier(trainingData: trainingData, textColumn: "text", labelColumn: "label")
52-
87+
5388
classifier.modelParameters.algorithm
54-
89+
5590
let evaluation = classifier.evaluation(on: testingData)
5691
print(evaluation)
57-
92+
5893
let modelPath = URL(fileURLWithPath: destinationPath)
5994
try classifier.write(to: modelPath)
6095
} catch {

0 commit comments

Comments
 (0)